Skip to content

Commit 5096ba3

Browse files
authored
Move feature_names from SINDy.__init__() to SINDy.fit() (#635)
Fixes #387
1 parent ee7eda1 commit 5096ba3

File tree

9 files changed

+1073
-1028
lines changed

9 files changed

+1073
-1028
lines changed

examples/1_feature_overview/example.ipynb

Lines changed: 481 additions & 854 deletions
Large diffs are not rendered by default.

examples/1_feature_overview/example.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,10 @@ def f(x):
270270
df = pd.DataFrame(data=x_train, columns=["x", "y", "z"], index=t_train)
271271

272272
# The column names can be used as feature names
273-
model = ps.SINDy(feature_names=df.columns)
273+
model = ps.SINDy()
274274

275275
# Everything needs to be converted to numpy arrays to be passed in
276-
model.fit(df.values, t=df.index.values)
276+
model.fit(df.values, t=df.index.values, feature_names=df.columns)
277277
model.print()
278278

279279
# %% [markdown]
@@ -628,8 +628,8 @@ def f(x):
628628
n_subset=int(0.6 * x_train.shape[0]),
629629
)
630630

631-
model = ps.SINDy(optimizer=ensemble_optimizer, feature_names=feature_names)
632-
model.fit(x_train, t=dt)
631+
model = ps.SINDy(optimizer=ensemble_optimizer)
632+
model.fit(x_train, t=dt, feature_names=feature_names)
633633
ensemble_coefs = ensemble_optimizer.coef_list
634634
mean_ensemble = np.mean(ensemble_coefs, axis=0)
635635
std_ensemble = np.std(ensemble_coefs, axis=0)
@@ -638,9 +638,9 @@ def f(x):
638638
library_ensemble_optimizer = ps.EnsembleOptimizer(
639639
ps.STLSQ(threshold=1e-3, normalize_columns=False), library_ensemble=True
640640
)
641-
model = ps.SINDy(optimizer=library_ensemble_optimizer, feature_names=feature_names)
641+
model = ps.SINDy(optimizer=library_ensemble_optimizer)
642642

643-
model.fit(x_train, t=dt)
643+
model.fit(x_train, t=dt, feature_names=feature_names)
644644
library_ensemble_coefs = library_ensemble_optimizer.coef_list
645645
mean_library_ensemble = np.mean(library_ensemble_coefs, axis=0)
646646
std_library_ensemble = np.std(library_ensemble_coefs, axis=0)
@@ -768,8 +768,8 @@ def f(x):
768768

769769
# In[47]:
770770
feature_names = ["x", "y", "z"]
771-
model = ps.SINDy(feature_names=feature_names)
772-
model.fit(x_train, t=dt)
771+
model = ps.SINDy()
772+
model.fit(x_train, t=dt, feature_names=feature_names)
773773
model.print()
774774

775775
# %% [markdown]
@@ -865,8 +865,8 @@ def f(x):
865865
fourier_library = ps.FourierLibrary()
866866
combined_library = identity_library + fourier_library
867867

868-
model = ps.SINDy(feature_library=combined_library, feature_names=feature_names)
869-
model.fit(x_train, t=dt)
868+
model = ps.SINDy(feature_library=combined_library)
869+
model.fit(x_train, t=dt, feature_names=feature_names)
870870
model.print()
871871
model.get_feature_names()
872872

@@ -879,8 +879,8 @@ def f(x):
879879
fourier_library = ps.FourierLibrary()
880880
combined_library = identity_library * fourier_library
881881

882-
model = ps.SINDy(feature_library=combined_library, feature_names=feature_names)
883-
model.fit(x_train, t=dt)
882+
model = ps.SINDy(feature_library=combined_library)
883+
model.fit(x_train, t=dt, feature_names=feature_names)
884884
# model.print() # prints out long and unobvious model
885885
print("Feature names:\n", model.get_feature_names())
886886

@@ -940,8 +940,8 @@ def f(x):
940940
)
941941

942942
# Fit the model and print the library feature names to check success
943-
model = ps.SINDy(feature_library=generalized_library, feature_names=feature_names)
944-
model.fit(x_train, t=dt)
943+
model = ps.SINDy(feature_library=generalized_library)
944+
model.fit(x_train, t=dt, feature_names=feature_names)
945945
model.print()
946946
print("Feature names:\n", model.get_feature_names())
947947

@@ -1216,10 +1216,8 @@ def u_fun(t):
12161216
num_parameters=1,
12171217
)
12181218
opt = ps.STLSQ(threshold=1e-1, normalize_columns=False)
1219-
model = ps.SINDy(
1220-
feature_library=lib, optimizer=opt, feature_names=["x", "r"], discrete_time=True
1221-
)
1222-
model.fit(xs_train, u=rs_train, t=1)
1219+
model = ps.SINDy(feature_library=lib, optimizer=opt, discrete_time=True)
1220+
model.fit(xs_train, u=rs_train, t=1, feature_names=["x", "r"])
12231221
model.print()
12241222

12251223
# %% [markdown]

0 commit comments

Comments
 (0)