@@ -726,23 +726,20 @@ def __init__(
726726 self ,
727727 optimizer : Optional [BaseOptimizer ] = None ,
728728 feature_library : Optional [BaseFeatureLibrary ] = None ,
729- differentiation_method : Optional [BaseDifferentiation ] = None ,
730729 ):
731730 if optimizer is None :
732731 optimizer = STLSQ ()
733732 self .optimizer = optimizer
734733 if feature_library is None :
735734 feature_library = PolynomialLibrary ()
736735 self .feature_library = feature_library
737- if differentiation_method is None :
738- differentiation_method = FiniteDifference (axis = - 2 )
739- self .differentiation_method = differentiation_method
740736
741737 def fit (
742738 self ,
743739 x ,
744740 t ,
745741 u = None ,
742+ x_next = None ,
746743 feature_names : Optional [list [str ]] = None ,
747744 ):
748745 """
@@ -782,21 +779,23 @@ def fit(
782779 self: a fitted :class:`DiscreteSINDy` instance
783780 """
784781
785- if not _check_multiple_trajectories (x , None , u ):
786- x , t , _ , u = _adapt_to_multiple_trajectories (x , t , None , u )
787- x , _ , u = _comprehend_and_validate_inputs (
788- x , t , None , u , self .feature_library
782+ if not _check_multiple_trajectories (x , x_next , u ):
783+ x , t , x_next , u = _adapt_to_multiple_trajectories (x , t , x_next , u )
784+ x , x_next , u = _comprehend_and_validate_inputs (
785+ x , t , x_next , u , self .feature_library
789786 )
790787
791- x_next = [xi [1 :] for xi in x ]
792- x = [xi [:- 1 ] for xi in x ]
788+ if x_next is None :
789+ x_next = [xi [1 :] for xi in x ]
790+ x = [xi [:- 1 ] for xi in x ]
791+ if u is not None :
792+ u = [ui [:- 1 ] for ui in u ]
793793
794794 # Append control variables
795795 if u is None :
796796 self .n_control_features_ = 0
797797 else :
798798 u = validate_control_variables (x , u )
799- u = u [:- 1 ]
800799 self .n_control_features_ = u [0 ].n_coord
801800
802801 x = [np .concatenate ((xi , ui ), axis = xi .ax_coord ) for xi , ui in zip (x , u )]
@@ -886,7 +885,7 @@ def print(self, precision=3, **kwargs):
886885 names = f"({ feature_names [i ]} )[k+1]"
887886 print (f"{ names } = { eqn } " , ** kwargs )
888887
889- def score (self , x , t , u = None , metric = r2_score , ** metric_kws ):
888+ def score (self , x , t , u = None , x_next = None , metric = r2_score , ** metric_kws ):
890889 """
891890 Returns a score for the next state prediction produced by the model.
892891
@@ -922,17 +921,18 @@ def score(self, x, t, u=None, metric=r2_score, **metric_kws):
922921 Metric function value for the model prediction of x_next.
923922 """
924923
925- if not _check_multiple_trajectories (x , None , u ):
926- x , t , _ , u = _adapt_to_multiple_trajectories (x , t , None , u )
927- x , _ , u = _comprehend_and_validate_inputs (
928- x , t , None , u , self .feature_library
924+ if not _check_multiple_trajectories (x , x_next , u ):
925+ x , t , x_next , u = _adapt_to_multiple_trajectories (x , t , x_next , u )
926+ x , x_next , u = _comprehend_and_validate_inputs (
927+ x , t , x_next , u , self .feature_library
929928 )
930929
931930 x_next_predict = self .predict (x , u )
932- x_next_predict = [xd [:- 1 ] for xd in x_next_predict ]
933931
934- x_next = [xi [1 :] for xi in x ]
935- x = [xi [:- 1 ] for xi in x ]
932+ if x_next is None :
933+ x_next_predict = [xd [:- 1 ] for xd in x_next_predict ]
934+ x_next = [xi [1 :] for xi in x ]
935+ x = [xi [:- 1 ] for xi in x ]
936936
937937 x_next = concat_sample_axis (x_next )
938938 x_next_predict = concat_sample_axis (x_next_predict )
@@ -963,11 +963,6 @@ def simulate(
963963 A list (with ``len(u) == t``) or array (with ``u.shape[0] == 1``)
964964 giving the control inputs at each step.
965965
966- integrator: string, optional (default ``solve_ivp``)
967- Function to use to integrate the system.
968- Default is ``scipy.integrate.solve_ivp``. The only options
969- currently supported are solve_ivp and odeint.
970-
971966 stop_condition: function object, optional
972967 If model is in discrete time, optional function that gives a
973968 stopping condition for stepping the simulation forward.
0 commit comments