@@ -131,7 +131,10 @@ def __add__(self, other):
131131 return ConcatLibrary ([self , other ])
132132
133133 def __mul__ (self , other ):
134- return TensoredLibrary ([self , other ])
134+ if isinstance (self , TensoredLibrary ):
135+ return TensoredLibrary (self .libraries + [other ])
136+ else :
137+ return TensoredLibrary ([self , other ])
135138
136139 def __rmul__ (self , other ):
137140 return TensoredLibrary ([self , other ])
@@ -451,30 +454,19 @@ def transform(self, x_full):
451454
452455 xp_full = []
453456 for x in x_full :
454- xp = []
455- for i in range (len (self .libraries )):
456- lib_i = self .libraries [i ]
457- if self .inputs_per_library is None :
458- xp_i = lib_i .transform ([x ])[0 ]
459- else :
460- xp_i = lib_i .transform (
461- [x [..., _unique (self .inputs_per_library [i ])]]
462- )[0 ]
463-
464- for j in range (i + 1 , len (self .libraries )):
465- lib_j = self .libraries [j ]
466- xp_j = lib_j .transform (
467- [x [..., _unique (self .inputs_per_library [j ])]]
468- )[0 ]
469-
470- xp .append (self ._combinations (xp_i , xp_j ))
471-
472- xp = np .concatenate (xp , axis = xp [0 ].ax_coord )
473- xp = AxesArray (xp , comprehend_axes (xp ))
457+ xp = self .libraries [0 ].transform (
458+ [x [..., _unique (self .inputs_per_library [0 ])]]
459+ )[0 ]
460+ for i in range (1 , len (self .libraries )):
461+ xp_i = self .libraries [i ].transform (
462+ [x [..., _unique (self .inputs_per_library [i ])]]
463+ )[0 ]
464+ xp = self ._combinations (xp , xp_i )
465+
474466 xp_full .append (xp )
475467 return xp_full
476468
477- def get_feature_names (self , input_features = None ):
469+ def get_feature_names (self , input_features : list [ str ] | None = None ) -> list [ str ] :
478470 """Return feature names for output features.
479471
480472 Parameters
@@ -487,32 +479,22 @@ def get_feature_names(self, input_features=None):
487479 -------
488480 output_feature_names : list of string, length n_output_features
489481 """
490- feature_names = list ()
491- for i in range (len (self .libraries )):
492- lib_i = self .libraries [i ]
493- if input_features is None :
494- input_features_i = [
495- "x%d" % k for k in _unique (self .inputs_per_library [i ])
496- ]
497- else :
498- input_features_i = np .asarray (input_features )[
499- _unique (self .inputs_per_library [i ])
500- ].tolist ()
501- lib_i_feat_names = lib_i .get_feature_names (input_features_i )
502- for j in range (i + 1 , len (self .libraries )):
503- lib_j = self .libraries [j ]
504- if input_features is None :
505- input_features_j = [
506- "x%d" % k for k in _unique (self .inputs_per_library [j ])
507- ]
508- else :
509- input_features_j = np .asarray (input_features )[
510- _unique (self .inputs_per_library [j ])
511- ].tolist ()
512- lib_j_feat_names = lib_j .get_feature_names (input_features_j )
513- feature_names += self ._name_combinations (
514- lib_i_feat_names , lib_j_feat_names
515- )
482+ check_is_fitted (self )
483+
484+ if input_features is None :
485+ input_features = ["x%d" % i for i in range (self .n_features_in_ )]
486+
487+ feature_names = self .libraries [0 ].get_feature_names (
488+ np .asarray (input_features )[_unique (self .inputs_per_library [0 ])].tolist ()
489+ )
490+
491+ for i in range (1 , len (self .libraries )):
492+ cur_input_features = np .asarray (input_features )[
493+ _unique (self .inputs_per_library [i ])
494+ ].tolist ()
495+ lib_i_feat_names = self .libraries [i ].get_feature_names (cur_input_features )
496+ feature_names = self ._name_combinations (feature_names , lib_i_feat_names )
497+
516498 return feature_names
517499
518500 def calc_trajectory (self , diff_method , x , t ):
0 commit comments