@@ -82,6 +82,59 @@ def _fit_shape(self):
8282 feature_names .append ("u" + str (i ))
8383 self .feature_names = feature_names
8484
85+ def predict (self , x , u = None ):
86+ """
87+ Predict the time derivatives if it is a SINDy model.
88+ Predict the next state of the system if it is a DiscreteSINDy model.
89+
90+
91+ Parameters
92+ ----------
93+ x: array-like or list of array-like, shape (n_samples, n_input_features)
94+ Samples.
95+
96+ u: array-like or list of array-like, shape(n_samples, n_control_features), \
97+ (default None)
98+ Control variables. If ``multiple_trajectories==True`` then u
99+ must be a list of control variable data from each trajectory. If the
100+ model was fit with control variables then u is not optional.
101+
102+ Returns
103+ -------
104+ x_next: array-like or list of array-like, shape (n_samples, n_input_features)
105+ Predicted next state of the system
106+ """
107+ if not _check_multiple_trajectories (x , None , u ):
108+ x , _ , _ , u = _adapt_to_multiple_trajectories (x , None , None , u )
109+ multiple_trajectories = False
110+ else :
111+ multiple_trajectories = True
112+
113+ x , _ , u = _comprehend_and_validate_inputs (x , 1 , None , u , self .feature_library )
114+
115+ check_is_fitted (self , "model" )
116+ if self .n_control_features_ > 0 and u is None :
117+ raise TypeError ("Model was fit using control variables, so u is required" )
118+ if self .n_control_features_ == 0 and u is not None :
119+ warnings .warn (
120+ "Control variables u were ignored because control variables were"
121+ " not used when the model was fit"
122+ )
123+ u = None
124+ if u is not None :
125+ u = validate_control_variables (x , u )
126+ x = [np .concatenate ((xi , ui ), axis = xi .ax_coord ) for xi , ui in zip (x , u )]
127+ result = [self .model .predict ([xi ]) for xi in x ]
128+ result = [
129+ self .feature_library .reshape_samples_to_spatial_grid (pred )
130+ for pred in result
131+ ]
132+
133+ # Kept for backwards compatibility.
134+ if not multiple_trajectories :
135+ return result [0 ]
136+ return result
137+
85138 def coefficients (self ):
86139 """
87140 Get an array of the coefficients learned by SINDy model.
@@ -367,57 +420,6 @@ def fit(
367420
368421 return self
369422
370- def predict (self , x , u = None ):
371- """
372- Predict the time derivatives using the SINDy model.
373-
374- Parameters
375- ----------
376- x: array-like or list of array-like, shape (n_samples, n_input_features)
377- Samples.
378-
379- u: array-like or list of array-like, shape(n_samples, n_control_features), \
380- (default None)
381- Control variables. If ``multiple_trajectories==True`` then u
382- must be a list of control variable data from each trajectory. If the
383- model was fit with control variables then u is not optional.
384-
385- Returns
386- -------
387- x_dot: array-like or list of array-like, shape (n_samples, n_input_features)
388- Predicted time derivatives
389- """
390- if not _check_multiple_trajectories (x , None , u ):
391- x , _ , _ , u = _adapt_to_multiple_trajectories (x , None , None , u )
392- multiple_trajectories = False
393- else :
394- multiple_trajectories = True
395-
396- x , _ , u = _comprehend_and_validate_inputs (x , 1 , None , u , self .feature_library )
397-
398- check_is_fitted (self , "model" )
399- if self .n_control_features_ > 0 and u is None :
400- raise TypeError ("Model was fit using control variables, so u is required" )
401- if self .n_control_features_ == 0 and u is not None :
402- warnings .warn (
403- "Control variables u were ignored because control variables were"
404- " not used when the model was fit"
405- )
406- u = None
407- if u is not None :
408- u = validate_control_variables (x , u )
409- x = [np .concatenate ((xi , ui ), axis = xi .ax_coord ) for xi , ui in zip (x , u )]
410- result = [self .model .predict ([xi ]) for xi in x ]
411- result = [
412- self .feature_library .reshape_samples_to_spatial_grid (pred )
413- for pred in result
414- ]
415-
416- # Kept for backwards compatibility.
417- if not multiple_trajectories :
418- return result [0 ]
419- return result
420-
421423 def print (self , lhs = None , precision = 3 , ** kwargs ):
422424 """Print the SINDy model equations.
423425
@@ -885,58 +887,6 @@ def fit(
885887
886888 return self
887889
888- def predict (self , x , u = None ):
889- """
890- Predict the time derivatives using the DiscreteSINDy model.
891-
892- Parameters
893- ----------
894- x: array-like or list of array-like, shape (n_samples, n_input_features)
895- Samples.
896-
897- u: array-like or list of array-like, shape(n_samples, n_control_features), \
898- (default None)
899- Control variables. If ``multiple_trajectories==True`` then u
900- must be a list of control variable data from each trajectory. If the
901- model was fit with control variables then u is not optional.
902-
903- Returns
904- -------
905- x_next: array-like or list of array-like, shape (n_samples, n_input_features)
906- Predicted next state of the system
907- """
908- if not _check_multiple_trajectories (x , None , u ):
909- x , _ , _ , u = _adapt_to_multiple_trajectories (x , None , None , u )
910- multiple_trajectories = False
911- else :
912- multiple_trajectories = True
913-
914- x , _ , u = _comprehend_and_validate_inputs (x , 1 , None , u , self .feature_library )
915-
916- check_is_fitted (self , "model" )
917- if self .n_control_features_ > 0 and u is None :
918- raise TypeError ("Model was fit using control variables, so u is required" )
919- if self .n_control_features_ == 0 and u is not None :
920- warnings .warn (
921- "Control variables u were ignored because control variables were"
922- " not used when the model was fit"
923- )
924- u = None
925- x = [validate_input (xi ) for xi in x ]
926- if u is not None :
927- u = validate_control_variables (x , u )
928- x = [np .concatenate ((xi , ui ), axis = xi .ax_coord ) for xi , ui in zip (x , u )]
929- result = [self .model .predict ([xi ]) for xi in x ]
930- result = [
931- self .feature_library .reshape_samples_to_spatial_grid (pred )
932- for pred in result
933- ]
934-
935- # Kept for backwards compatibility.
936- if not multiple_trajectories :
937- return result [0 ]
938- return result
939-
940890 def print (self , precision = 3 , ** kwargs ):
941891 """Print the DiscreteSINDy model equations.
942892
0 commit comments