Skip to content

Commit 0d5f6cf

Browse files
fix: Flatten libraries when tensoring a TensorLibrary
Fixes #540 Previously, when `TensoredLibrary.libraries` included another `TensoredLibrary`, they could not correctly assign `inputs_per_library`. Also adds `test_tensored_library`
1 parent 816382a commit 0d5f6cf

File tree

2 files changed

+55
-48
lines changed

2 files changed

+55
-48
lines changed

pysindy/feature_library/base.py

Lines changed: 30 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,10 @@ def __add__(self, other):
131131
return ConcatLibrary([self, other])
132132

133133
def __mul__(self, other):
134-
return TensoredLibrary([self, other])
134+
if isinstance(self, TensoredLibrary):
135+
return TensoredLibrary(self.libraries + [other])
136+
else:
137+
return TensoredLibrary([self, other])
135138

136139
def __rmul__(self, other):
137140
return TensoredLibrary([self, other])
@@ -451,30 +454,19 @@ def transform(self, x_full):
451454

452455
xp_full = []
453456
for x in x_full:
454-
xp = []
455-
for i in range(len(self.libraries)):
456-
lib_i = self.libraries[i]
457-
if self.inputs_per_library is None:
458-
xp_i = lib_i.transform([x])[0]
459-
else:
460-
xp_i = lib_i.transform(
461-
[x[..., _unique(self.inputs_per_library[i])]]
462-
)[0]
463-
464-
for j in range(i + 1, len(self.libraries)):
465-
lib_j = self.libraries[j]
466-
xp_j = lib_j.transform(
467-
[x[..., _unique(self.inputs_per_library[j])]]
468-
)[0]
469-
470-
xp.append(self._combinations(xp_i, xp_j))
471-
472-
xp = np.concatenate(xp, axis=xp[0].ax_coord)
473-
xp = AxesArray(xp, comprehend_axes(xp))
457+
xp = self.libraries[0].transform(
458+
[x[..., _unique(self.inputs_per_library[0])]]
459+
)[0]
460+
for i in range(1, len(self.libraries)):
461+
xp_i = self.libraries[i].transform(
462+
[x[..., _unique(self.inputs_per_library[i])]]
463+
)[0]
464+
xp = self._combinations(xp, xp_i)
465+
474466
xp_full.append(xp)
475467
return xp_full
476468

477-
def get_feature_names(self, input_features=None):
469+
def get_feature_names(self, input_features: list[str] | None = None) -> list[str]:
478470
"""Return feature names for output features.
479471
480472
Parameters
@@ -487,32 +479,22 @@ def get_feature_names(self, input_features=None):
487479
-------
488480
output_feature_names : list of string, length n_output_features
489481
"""
490-
feature_names = list()
491-
for i in range(len(self.libraries)):
492-
lib_i = self.libraries[i]
493-
if input_features is None:
494-
input_features_i = [
495-
"x%d" % k for k in _unique(self.inputs_per_library[i])
496-
]
497-
else:
498-
input_features_i = np.asarray(input_features)[
499-
_unique(self.inputs_per_library[i])
500-
].tolist()
501-
lib_i_feat_names = lib_i.get_feature_names(input_features_i)
502-
for j in range(i + 1, len(self.libraries)):
503-
lib_j = self.libraries[j]
504-
if input_features is None:
505-
input_features_j = [
506-
"x%d" % k for k in _unique(self.inputs_per_library[j])
507-
]
508-
else:
509-
input_features_j = np.asarray(input_features)[
510-
_unique(self.inputs_per_library[j])
511-
].tolist()
512-
lib_j_feat_names = lib_j.get_feature_names(input_features_j)
513-
feature_names += self._name_combinations(
514-
lib_i_feat_names, lib_j_feat_names
515-
)
482+
check_is_fitted(self)
483+
484+
if input_features is None:
485+
input_features = ["x%d" % i for i in range(self.n_features_in_)]
486+
487+
feature_names = self.libraries[0].get_feature_names(
488+
np.asarray(input_features)[_unique(self.inputs_per_library[0])].tolist()
489+
)
490+
491+
for i in range(1, len(self.libraries)):
492+
cur_input_features = np.asarray(input_features)[
493+
_unique(self.inputs_per_library[i])
494+
].tolist()
495+
lib_i_feat_names = self.libraries[i].get_feature_names(cur_input_features)
496+
feature_names = self._name_combinations(feature_names, lib_i_feat_names)
497+
516498
return feature_names
517499

518500
def calc_trajectory(self, diff_method, x, t):

test/test_feature_library.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,31 @@ def test_not_fitted(data_lorenz, library):
423423
library.transform(x)
424424

425425

426+
def test_tensored_library():
427+
f_list = [lambda x: x]
428+
lib_names_a = [lambda x: f"a({x})"]
429+
lib_names_b = [lambda x: f"b({x})"]
430+
lib_names_c = [lambda x: f"c({x})"]
431+
432+
lib_a = CustomLibrary(f_list, lib_names_a)
433+
lib_b = CustomLibrary(f_list, lib_names_b)
434+
lib_c = CustomLibrary(f_list, lib_names_c)
435+
436+
libraries = [lib_a, lib_b, lib_c]
437+
438+
inputs_per_library = [[1], [1], [0]]
439+
x = np.array([[1, 2]])
440+
441+
tensored_library = np.prod(np.asarray(libraries))
442+
tensored_library._set_inputs_per_library(inputs_per_library)
443+
444+
values = tensored_library.fit_transform(x)
445+
names = tensored_library.get_feature_names(input_features=["x", "y"])
446+
447+
assert names == ["a(y) b(y) c(x)"]
448+
assert values.item() == 4
449+
450+
426451
def test_generalized_library(data_lorenz):
427452
x, t = data_lorenz
428453
poly_library = PolynomialLibrary(include_bias=False)

0 commit comments

Comments
 (0)