@@ -143,7 +143,7 @@ def coefficients(self):
143143 check_is_fitted (self )
144144 return self .optimizer .coef_
145145
146- def equations (self , precision : int = 3 , discrete_time = False ) -> list [str ]:
146+ def equations (self , precision : int = 3 ) -> list [str ]:
147147 """
148148 Get the right hand sides of the SINDy model equations.
149149
@@ -160,10 +160,7 @@ def equations(self, precision: int = 3, discrete_time=False) -> list[str]:
160160 input feature.
161161 """
162162 check_is_fitted (self , "model" )
163- if discrete_time :
164- sys_coord_names = [name + "[k]" for name in self .feature_names ]
165- else :
166- sys_coord_names = self .feature_names
163+ sys_coord_names = self .feature_names
167164 feat_names = self .feature_library .get_feature_names (sys_coord_names )
168165
169166 def term (c , name ):
@@ -882,6 +879,43 @@ def fit(
882879
883880 return self
884881
882+ def equations (self , precision : int = 3 ) -> list [str ]:
883+ """
884+ Get the right hand sides of the DiscreteSINDy model equations.
885+
886+ Parameters
887+ ----------
888+ precision: int, optional (default 3)
889+ Number of decimal points to include for each coefficient in the
890+ equation.
891+
892+ Returns
893+ -------
894+ equations: list of strings
895+ List of strings representing the DiscreteSINDy model equations for each
896+ input feature.
897+ """
898+ check_is_fitted (self , "model" )
899+ sys_coord_names = [name + "[k]" for name in self .feature_names ]
900+ feat_names = self .feature_library .get_feature_names (sys_coord_names )
901+
902+ def term (c , name ):
903+ rounded_coef = np .round (c , precision )
904+ if rounded_coef == 0 :
905+ return ""
906+ else :
907+ return f"{ c :.{precision }f} { name } "
908+
909+ equations = []
910+ for coef_row in self .optimizer .coef_ :
911+ components = [term (c , i ) for c , i in zip (coef_row , feat_names )]
912+ eq = " + " .join (filter (bool , components ))
913+ if not eq :
914+ eq = f"{ 0 :.{precision }f} "
915+ equations .append (eq )
916+
917+ return equations
918+
885919 def print (self , precision = 3 , ** kwargs ):
886920 """Print the DiscreteSINDy model equations.
887921
@@ -892,7 +926,7 @@ def print(self, precision=3, **kwargs):
892926
893927 **kwargs: Additional keyword arguments passed to the builtin print function
894928 """
895- eqns = self .equations (precision , discrete_time = True )
929+ eqns = self .equations (precision )
896930 feature_names = self .feature_names
897931 for i , eqn in enumerate (eqns ):
898932 names = f"({ feature_names [i ]} )[k+1]"
@@ -1051,32 +1085,17 @@ def _check_multiple_trajectories(x, x_dot, u) -> bool:
10511085
10521086 """
10531087 SequenceOrNone = Union [Sequence , None ]
1054- if sys .version_info .minor < 10 :
1055- mixed_trajectories = (
1056- isinstance (x , Sequence )
1057- and (
1058- not isinstance (x_dot , Sequence )
1059- and x_dot is not None
1060- or not isinstance (u , Sequence )
1061- and u is not None
1062- )
1063- or isinstance (x_dot , Sequence )
1064- and not isinstance (x , Sequence )
1065- or isinstance (u , Sequence )
1066- and not isinstance (x , Sequence )
1067- )
1068- else :
1069- mixed_trajectories = (
1070- isinstance (x , Sequence )
1071- and (
1072- not isinstance (x_dot , SequenceOrNone )
1073- or not isinstance (u , SequenceOrNone )
1074- )
1075- or isinstance (x_dot , Sequence )
1076- and not isinstance (x , Sequence )
1077- or isinstance (u , Sequence )
1078- and not isinstance (x , Sequence )
1088+ mixed_trajectories = (
1089+ isinstance (x , Sequence )
1090+ and (
1091+ not isinstance (x_dot , SequenceOrNone )
1092+ or not isinstance (u , SequenceOrNone )
10791093 )
1094+ or isinstance (x_dot , Sequence )
1095+ and not isinstance (x , Sequence )
1096+ or isinstance (u , Sequence )
1097+ and not isinstance (x , Sequence )
1098+ )
10801099 if mixed_trajectories :
10811100 raise TypeError (
10821101 "If x, x_dot, or u are a Sequence of trajectories, each must be a Sequence"
0 commit comments