Skip to content

Commit 6549ee7

Browse files
committed
CLN: separate equations function for DiscreteSINDy and code cleanup
1 parent 56dc1c6 commit 6549ee7

File tree

3 files changed

+52
-32
lines changed

3 files changed

+52
-32
lines changed

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"sphinx.ext.mathjax",
3232
"sphinx.ext.intersphinx",
3333
"IPython.sphinxext.ipython_console_highlighting",
34-
"matplotlib.sphinxext.plot_directive"
34+
"matplotlib.sphinxext.plot_directive",
3535
]
3636

3737
nb_execution_mode = "off"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ docs = [
6060
"sphinx==8.2.3",
6161
"pyyaml",
6262
"sphinxcontrib-apidoc",
63+
"matplotlib"
6364
]
6465
miosr = [
6566
"gurobipy>=9.5.1,!=10.0.0"

pysindy/_core.py

Lines changed: 50 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def coefficients(self):
143143
check_is_fitted(self)
144144
return self.optimizer.coef_
145145

146-
def equations(self, precision: int = 3, discrete_time=False) -> list[str]:
146+
def equations(self, precision: int = 3) -> list[str]:
147147
"""
148148
Get the right hand sides of the SINDy model equations.
149149
@@ -160,10 +160,7 @@ def equations(self, precision: int = 3, discrete_time=False) -> list[str]:
160160
input feature.
161161
"""
162162
check_is_fitted(self, "model")
163-
if discrete_time:
164-
sys_coord_names = [name + "[k]" for name in self.feature_names]
165-
else:
166-
sys_coord_names = self.feature_names
163+
sys_coord_names = self.feature_names
167164
feat_names = self.feature_library.get_feature_names(sys_coord_names)
168165

169166
def term(c, name):
@@ -882,6 +879,43 @@ def fit(
882879

883880
return self
884881

882+
def equations(self, precision: int = 3) -> list[str]:
883+
"""
884+
Get the right hand sides of the DiscreteSINDy model equations.
885+
886+
Parameters
887+
----------
888+
precision: int, optional (default 3)
889+
Number of decimal points to include for each coefficient in the
890+
equation.
891+
892+
Returns
893+
-------
894+
equations: list of strings
895+
List of strings representing the DiscreteSINDy model equations for each
896+
input feature.
897+
"""
898+
check_is_fitted(self, "model")
899+
sys_coord_names = [name + "[k]" for name in self.feature_names]
900+
feat_names = self.feature_library.get_feature_names(sys_coord_names)
901+
902+
def term(c, name):
903+
rounded_coef = np.round(c, precision)
904+
if rounded_coef == 0:
905+
return ""
906+
else:
907+
return f"{c:.{precision}f} {name}"
908+
909+
equations = []
910+
for coef_row in self.optimizer.coef_:
911+
components = [term(c, i) for c, i in zip(coef_row, feat_names)]
912+
eq = " + ".join(filter(bool, components))
913+
if not eq:
914+
eq = f"{0:.{precision}f}"
915+
equations.append(eq)
916+
917+
return equations
918+
885919
def print(self, precision=3, **kwargs):
886920
"""Print the DiscreteSINDy model equations.
887921
@@ -892,7 +926,7 @@ def print(self, precision=3, **kwargs):
892926
893927
**kwargs: Additional keyword arguments passed to the builtin print function
894928
"""
895-
eqns = self.equations(precision, discrete_time=True)
929+
eqns = self.equations(precision)
896930
feature_names = self.feature_names
897931
for i, eqn in enumerate(eqns):
898932
names = f"({feature_names[i]})[k+1]"
@@ -1051,32 +1085,17 @@ def _check_multiple_trajectories(x, x_dot, u) -> bool:
10511085
10521086
"""
10531087
SequenceOrNone = Union[Sequence, None]
1054-
if sys.version_info.minor < 10:
1055-
mixed_trajectories = (
1056-
isinstance(x, Sequence)
1057-
and (
1058-
not isinstance(x_dot, Sequence)
1059-
and x_dot is not None
1060-
or not isinstance(u, Sequence)
1061-
and u is not None
1062-
)
1063-
or isinstance(x_dot, Sequence)
1064-
and not isinstance(x, Sequence)
1065-
or isinstance(u, Sequence)
1066-
and not isinstance(x, Sequence)
1067-
)
1068-
else:
1069-
mixed_trajectories = (
1070-
isinstance(x, Sequence)
1071-
and (
1072-
not isinstance(x_dot, SequenceOrNone)
1073-
or not isinstance(u, SequenceOrNone)
1074-
)
1075-
or isinstance(x_dot, Sequence)
1076-
and not isinstance(x, Sequence)
1077-
or isinstance(u, Sequence)
1078-
and not isinstance(x, Sequence)
1088+
mixed_trajectories = (
1089+
isinstance(x, Sequence)
1090+
and (
1091+
not isinstance(x_dot, SequenceOrNone)
1092+
or not isinstance(u, SequenceOrNone)
10791093
)
1094+
or isinstance(x_dot, Sequence)
1095+
and not isinstance(x, Sequence)
1096+
or isinstance(u, Sequence)
1097+
and not isinstance(x, Sequence)
1098+
)
10801099
if mixed_trajectories:
10811100
raise TypeError(
10821101
"If x, x_dot, or u are a Sequence of trajectories, each must be a Sequence"

0 commit comments

Comments
 (0)