Skip to content

Commit 38868a8

Browse files
committed
CLN: use predict from BaseSINDy for SINDy and DiscreteSINDy models
1 parent 9cf9d28 commit 38868a8

File tree

1 file changed

+53
-103
lines changed

1 file changed

+53
-103
lines changed

pysindy/pysindy.py

Lines changed: 53 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,59 @@ def _fit_shape(self):
8282
feature_names.append("u" + str(i))
8383
self.feature_names = feature_names
8484

85+
def predict(self, x, u=None):
86+
"""
87+
Predict the time derivatives if it is a SINDy model.
88+
Predict the next state of the system if it is a DiscreteSINDy model.
89+
90+
91+
Parameters
92+
----------
93+
x: array-like or list of array-like, shape (n_samples, n_input_features)
94+
Samples.
95+
96+
u: array-like or list of array-like, shape(n_samples, n_control_features), \
97+
(default None)
98+
Control variables. If ``multiple_trajectories==True`` then u
99+
must be a list of control variable data from each trajectory. If the
100+
model was fit with control variables then u is not optional.
101+
102+
Returns
103+
-------
104+
x_next: array-like or list of array-like, shape (n_samples, n_input_features)
105+
Predicted next state of the system
106+
"""
107+
if not _check_multiple_trajectories(x, None, u):
108+
x, _, _, u = _adapt_to_multiple_trajectories(x, None, None, u)
109+
multiple_trajectories = False
110+
else:
111+
multiple_trajectories = True
112+
113+
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)
114+
115+
check_is_fitted(self, "model")
116+
if self.n_control_features_ > 0 and u is None:
117+
raise TypeError("Model was fit using control variables, so u is required")
118+
if self.n_control_features_ == 0 and u is not None:
119+
warnings.warn(
120+
"Control variables u were ignored because control variables were"
121+
" not used when the model was fit"
122+
)
123+
u = None
124+
if u is not None:
125+
u = validate_control_variables(x, u)
126+
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]
127+
result = [self.model.predict([xi]) for xi in x]
128+
result = [
129+
self.feature_library.reshape_samples_to_spatial_grid(pred)
130+
for pred in result
131+
]
132+
133+
# Kept for backwards compatibility.
134+
if not multiple_trajectories:
135+
return result[0]
136+
return result
137+
85138
def coefficients(self):
86139
"""
87140
Get an array of the coefficients learned by SINDy model.
@@ -367,57 +420,6 @@ def fit(
367420

368421
return self
369422

370-
def predict(self, x, u=None):
371-
"""
372-
Predict the time derivatives using the SINDy model.
373-
374-
Parameters
375-
----------
376-
x: array-like or list of array-like, shape (n_samples, n_input_features)
377-
Samples.
378-
379-
u: array-like or list of array-like, shape(n_samples, n_control_features), \
380-
(default None)
381-
Control variables. If ``multiple_trajectories==True`` then u
382-
must be a list of control variable data from each trajectory. If the
383-
model was fit with control variables then u is not optional.
384-
385-
Returns
386-
-------
387-
x_dot: array-like or list of array-like, shape (n_samples, n_input_features)
388-
Predicted time derivatives
389-
"""
390-
if not _check_multiple_trajectories(x, None, u):
391-
x, _, _, u = _adapt_to_multiple_trajectories(x, None, None, u)
392-
multiple_trajectories = False
393-
else:
394-
multiple_trajectories = True
395-
396-
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)
397-
398-
check_is_fitted(self, "model")
399-
if self.n_control_features_ > 0 and u is None:
400-
raise TypeError("Model was fit using control variables, so u is required")
401-
if self.n_control_features_ == 0 and u is not None:
402-
warnings.warn(
403-
"Control variables u were ignored because control variables were"
404-
" not used when the model was fit"
405-
)
406-
u = None
407-
if u is not None:
408-
u = validate_control_variables(x, u)
409-
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]
410-
result = [self.model.predict([xi]) for xi in x]
411-
result = [
412-
self.feature_library.reshape_samples_to_spatial_grid(pred)
413-
for pred in result
414-
]
415-
416-
# Kept for backwards compatibility.
417-
if not multiple_trajectories:
418-
return result[0]
419-
return result
420-
421423
def print(self, lhs=None, precision=3, **kwargs):
422424
"""Print the SINDy model equations.
423425
@@ -885,58 +887,6 @@ def fit(
885887

886888
return self
887889

888-
def predict(self, x, u=None):
889-
"""
890-
Predict the time derivatives using the DiscreteSINDy model.
891-
892-
Parameters
893-
----------
894-
x: array-like or list of array-like, shape (n_samples, n_input_features)
895-
Samples.
896-
897-
u: array-like or list of array-like, shape(n_samples, n_control_features), \
898-
(default None)
899-
Control variables. If ``multiple_trajectories==True`` then u
900-
must be a list of control variable data from each trajectory. If the
901-
model was fit with control variables then u is not optional.
902-
903-
Returns
904-
-------
905-
x_next: array-like or list of array-like, shape (n_samples, n_input_features)
906-
Predicted next state of the system
907-
"""
908-
if not _check_multiple_trajectories(x, None, u):
909-
x, _, _, u = _adapt_to_multiple_trajectories(x, None, None, u)
910-
multiple_trajectories = False
911-
else:
912-
multiple_trajectories = True
913-
914-
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)
915-
916-
check_is_fitted(self, "model")
917-
if self.n_control_features_ > 0 and u is None:
918-
raise TypeError("Model was fit using control variables, so u is required")
919-
if self.n_control_features_ == 0 and u is not None:
920-
warnings.warn(
921-
"Control variables u were ignored because control variables were"
922-
" not used when the model was fit"
923-
)
924-
u = None
925-
x = [validate_input(xi) for xi in x]
926-
if u is not None:
927-
u = validate_control_variables(x, u)
928-
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]
929-
result = [self.model.predict([xi]) for xi in x]
930-
result = [
931-
self.feature_library.reshape_samples_to_spatial_grid(pred)
932-
for pred in result
933-
]
934-
935-
# Kept for backwards compatibility.
936-
if not multiple_trajectories:
937-
return result[0]
938-
return result
939-
940890
def print(self, precision=3, **kwargs):
941891
"""Print the DiscreteSINDy model equations.
942892

0 commit comments

Comments
 (0)