Skip to content

Commit bbc6c2c

Browse files
committed
CLN: ensure consistency of shapes and types of x_next and u, and compatibility with multiple trajectories for DiscreteSINDy
1 parent 9df5bcc commit bbc6c2c

File tree

2 files changed

+19
-25
lines changed

2 files changed

+19
-25
lines changed

pysindy/pysindy.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -726,23 +726,20 @@ def __init__(
726726
self,
727727
optimizer: Optional[BaseOptimizer] = None,
728728
feature_library: Optional[BaseFeatureLibrary] = None,
729-
differentiation_method: Optional[BaseDifferentiation] = None,
730729
):
731730
if optimizer is None:
732731
optimizer = STLSQ()
733732
self.optimizer = optimizer
734733
if feature_library is None:
735734
feature_library = PolynomialLibrary()
736735
self.feature_library = feature_library
737-
if differentiation_method is None:
738-
differentiation_method = FiniteDifference(axis=-2)
739-
self.differentiation_method = differentiation_method
740736

741737
def fit(
742738
self,
743739
x,
744740
t,
745741
u=None,
742+
x_next=None,
746743
feature_names: Optional[list[str]] = None,
747744
):
748745
"""
@@ -782,21 +779,23 @@ def fit(
782779
self: a fitted :class:`DiscreteSINDy` instance
783780
"""
784781

785-
if not _check_multiple_trajectories(x, None, u):
786-
x, t, _, u = _adapt_to_multiple_trajectories(x, t, None, u)
787-
x, _, u = _comprehend_and_validate_inputs(
788-
x, t, None, u, self.feature_library
782+
if not _check_multiple_trajectories(x, x_next, u):
783+
x, t, x_next, u = _adapt_to_multiple_trajectories(x, t, x_next, u)
784+
x, x_next, u = _comprehend_and_validate_inputs(
785+
x, t, x_next, u, self.feature_library
789786
)
790787

791-
x_next = [xi[1:] for xi in x]
792-
x = [xi[:-1] for xi in x]
788+
if x_next is None:
789+
x_next = [xi[1:] for xi in x]
790+
x = [xi[:-1] for xi in x]
791+
if u is not None:
792+
u = [ui[:-1] for ui in u]
793793

794794
# Append control variables
795795
if u is None:
796796
self.n_control_features_ = 0
797797
else:
798798
u = validate_control_variables(x, u)
799-
u = u[:-1]
800799
self.n_control_features_ = u[0].n_coord
801800

802801
x = [np.concatenate((xi, ui), axis=xi.ax_coord) for xi, ui in zip(x, u)]
@@ -886,7 +885,7 @@ def print(self, precision=3, **kwargs):
886885
names = f"({feature_names[i]})[k+1]"
887886
print(f"{names} = {eqn}", **kwargs)
888887

889-
def score(self, x, t, u=None, metric=r2_score, **metric_kws):
888+
def score(self, x, t, u=None, x_next=None, metric=r2_score, **metric_kws):
890889
"""
891890
Returns a score for the next state prediction produced by the model.
892891
@@ -922,17 +921,18 @@ def score(self, x, t, u=None, metric=r2_score, **metric_kws):
922921
Metric function value for the model prediction of x_next.
923922
"""
924923

925-
if not _check_multiple_trajectories(x, None, u):
926-
x, t, _, u = _adapt_to_multiple_trajectories(x, t, None, u)
927-
x, _, u = _comprehend_and_validate_inputs(
928-
x, t, None, u, self.feature_library
924+
if not _check_multiple_trajectories(x, x_next, u):
925+
x, t, x_next, u = _adapt_to_multiple_trajectories(x, t, x_next, u)
926+
x, x_next, u = _comprehend_and_validate_inputs(
927+
x, t, x_next, u, self.feature_library
929928
)
930929

931930
x_next_predict = self.predict(x, u)
932-
x_next_predict = [xd[:-1] for xd in x_next_predict]
933931

934-
x_next = [xi[1:] for xi in x]
935-
x = [xi[:-1] for xi in x]
932+
if x_next is None:
933+
x_next_predict = [xd[:-1] for xd in x_next_predict]
934+
x_next = [xi[1:] for xi in x]
935+
x = [xi[:-1] for xi in x]
936936

937937
x_next = concat_sample_axis(x_next)
938938
x_next_predict = concat_sample_axis(x_next_predict)
@@ -963,11 +963,6 @@ def simulate(
963963
A list (with ``len(u) == t``) or array (with ``u.shape[0] == 1``)
964964
giving the control inputs at each step.
965965
966-
integrator: string, optional (default ``solve_ivp``)
967-
Function to use to integrate the system.
968-
Default is ``scipy.integrate.solve_ivp``. The only options
969-
currently supported are solve_ivp and odeint.
970-
971966
stop_condition: function object, optional
972967
If model is in discrete time, optional function that gives a
973968
stopping condition for stepping the simulation forward.

pysindy/utils/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def validate_control_variables(
103103
Args:
104104
x: trajectories of system variables
105105
u: trajectories of control variables
106-
trim_last_point: whether to remove last time point of controls
107106
"""
108107
if not isinstance(x, Sequence):
109108
raise ValueError("x must be a Sequence")

0 commit comments

Comments
 (0)