Skip to content

Commit 404d6b3

Browse files
committed
CLN: lint the code
1 parent 393f3e1 commit 404d6b3

File tree

8 files changed

+35
-38
lines changed

8 files changed

+35
-38
lines changed

pysindy/feature_library/polynomial_library.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,11 @@ def get_feature_names(self, input_features=None):
150150
inds = np.where(row)[0]
151151
if len(inds):
152152
name = " ".join(
153-
"%s^%d" % (input_features[ind], exp)
154-
if exp != 1
155-
else input_features[ind]
153+
(
154+
"%s^%d" % (input_features[ind], exp)
155+
if exp != 1
156+
else input_features[ind]
157+
)
156158
for ind, exp in zip(inds, row[inds])
157159
)
158160
else:

pysindy/optimizers/trapping_sr3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ def find_symm_term(a, b):
861861

862862

863863
def _build_lib_info(
864-
polyterms: list[tuple[int, Int1D]]
864+
polyterms: list[tuple[int, Int1D]],
865865
) -> tuple[int, int, dict[int, int], dict[int, int], dict[frozenset[int], int]]:
866866
"""From polynomial, calculate various useful info
867867

pysindy/pysindy.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
from .utils import drop_nan_samples
3939
from .utils import SampleConcatter
4040
from .utils import validate_control_variables
41-
from .utils import validate_input
4241
from .utils import validate_no_reshape
4342

4443

@@ -568,17 +567,17 @@ def simulate(
568567
Initial condition from which to simulate.
569568
570569
t: numpy array of size [n_samples]
571-
Array of time points at which to simulate.
570+
Array of time points at which to simulate.
572571
573572
u: function from R^1 to R^{n_control_features} or list/array, optional \
574573
(default None)
575574
Control inputs.
576-
``u`` can be a function that would take time as input and output the values of each of
577-
the n_control_features control features as a list or numpy array.
578-
Alternatively, ``u`` can also be an array of control inputs at each time
579-
step. In this case, the array is fit with the interpolator specified
580-
by ``interpolator``.
581-
575+
``u`` can be a function that would take time as input and output
576+
the values of each of the n_control_features control features as
577+
a list or numpy array. Alternatively, ``u`` can also be an array
578+
of control inputs at each time step. In this case, the array is fit
579+
with the interpolator specified by ``interpolator``.
580+
582581
integrator: string, optional (default ``solve_ivp``)
583582
Function to use to integrate the system.
584583
Default is ``scipy.integrate.solve_ivp``. The only options
@@ -640,9 +639,7 @@ def rhs(t, x):
640639
if u_fun(t[0]).ndim == 1:
641640

642641
def rhs(t, x):
643-
return self.predict(x[np.newaxis, :], u_fun(t).reshape(1, -1))[
644-
0
645-
]
642+
return self.predict(x[np.newaxis, :], u_fun(t).reshape(1, -1))[0]
646643

647644
else:
648645

@@ -652,9 +649,7 @@ def rhs(t, x):
652649
# Need to hard-code below, because odeint and solve_ivp
653650
# have different syntax and integration options.
654651
if integrator == "solve_ivp":
655-
return (
656-
(solve_ivp(rhs, (t[0], t[-1]), x0, t_eval=t, **integrator_kws)).y
657-
).T
652+
return ((solve_ivp(rhs, (t[0], t[-1]), x0, t_eval=t, **integrator_kws)).y).T
658653
elif integrator == "odeint":
659654
if integrator_kws.get("method") == "LSODA":
660655
integrator_kws = {}
@@ -672,7 +667,8 @@ def complexity(self):
672667

673668
class DiscreteSINDy(_BaseSINDy):
674669
"""
675-
Sparse Identification of Nonlinear Dynamical Systems (SINDy) for discrete time systems.
670+
Sparse Identification of Nonlinear Dynamical Systems (SINDy) for discrete
671+
time systems.
676672
677673
Parameters
678674
----------
@@ -777,9 +773,10 @@ def fit(
777773
Parameters
778774
----------
779775
x: array-like or list of array-like, shape (n_samples, n_input_features)
780-
Training data of the current state of the system. If training data
781-
contains multiple trajectories, x should be a list containing data for
782-
each trajectory. Individual trajectories may contain different numbers of samples.
776+
Training data of the current state of the system. If training data
777+
contains multiple trajectories, x should be a list containing data for
778+
each trajectory. Individual trajectories may contain different numbers
779+
of samples.
783780
784781
t: float, numpy array of shape (n_samples,), or list of numpy arrays
785782
If t is a float, it specifies the timestep between each sample.
@@ -792,10 +789,10 @@ def fit(
792789
793790
x_next: array-like or list of array-like, shape (n_samples, n_input_features), \
794791
optional (default None)
795-
Optional data of the system forwarded by one time step. If not provided, the
792+
Optional data of the system forwarded by one time step. If not provided, the
796793
next will be computed by taking the training data by one time step.
797-
If x_next is provided, it must match the shape of the training data and these
798-
values will be used as the next state.
794+
If x_next is provided, it must match the shape of the training data and
795+
these values will be used as the next state.
799796
800797
u: array-like or list of array-like, shape (n_samples, n_control_features), \
801798
optional (default None)
@@ -944,7 +941,7 @@ def simulate(
944941
945942
u: list/array, optional (default None)
946943
Control inputs.
947-
A list (with ``len(u) == t``) or array (with ``u.shape[0] == 1``)
944+
A list (with ``len(u) == t``) or array (with ``u.shape[0] == 1``)
948945
giving the control inputs at each step.
949946
950947
stop_condition: function object, optional
@@ -961,15 +958,15 @@ def simulate(
961958
raise TypeError("Model was fit using control variables, so u is required")
962959

963960
if not isinstance(t, int) or t <= 0:
964-
raise ValueError(
965-
" t must be an integer"
966-
)
961+
raise ValueError(" t must be an integer")
967962

968963
if stop_condition is not None:
964+
969965
def check_stop_condition(xi):
970966
return stop_condition(xi)
971967

972968
else:
969+
973970
def check_stop_condition(xi):
974971
pass
975972

@@ -992,7 +989,7 @@ def check_stop_condition(xi):
992989
if check_stop_condition(x[i]):
993990
return x[: i + 1]
994991
return x
995-
992+
996993

997994
def _zip_like_sequence(x, t):
998995
"""Create an iterable like zip(x, t), but works if t is scalar.

pysindy/utils/axes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ def _tensordot_to_einsum(
676676
for a_ind, b_ind in zip(*axes):
677677
sub_b_li[b_ind] = sub_a[a_ind]
678678
sub_b = "".join(sub_b_li)
679-
sub = f"{sub_a},{sub_b}"
679+
sub = f"{sub_a}, {sub_b}"
680680
return sub
681681

682682

pysindy/utils/base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,7 @@ def validate_no_reshape(x, t: Union[float, np.ndarray, object] = T_DEFAULT):
9595
return x
9696

9797

98-
def validate_control_variables(
99-
x: Sequence[AxesArray], u: Sequence[AxesArray]
100-
) -> None:
98+
def validate_control_variables(x: Sequence[AxesArray], u: Sequence[AxesArray]) -> None:
10199
"""Ensure that control variables u are compatible with the data x.
102100
103101
Args:

test/test_pysindy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
from sklearn.model_selection import TimeSeriesSplit
2323
from sklearn.utils.validation import check_is_fitted
2424

25+
from pysindy import DiscreteSINDy
2526
from pysindy import pysindy
26-
from pysindy import SINDy, DiscreteSINDy
27+
from pysindy import SINDy
2728
from pysindy.differentiation import SINDyDerivative
2829
from pysindy.differentiation import SmoothedFiniteDifference
2930
from pysindy.feature_library import FourierLibrary

test/test_sindyc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from sklearn.linear_model import Lasso
99
from sklearn.utils.validation import check_is_fitted
1010

11-
from pysindy import SINDy, DiscreteSINDy
11+
from pysindy import DiscreteSINDy
12+
from pysindy import SINDy
1213
from pysindy.optimizers import SR3
1314
from pysindy.optimizers import STLSQ
1415

test/utils/test_utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
import pytest
33
from numpy.testing import assert_array_equal
44

5-
from pysindy.utils import AxesArray
65
from pysindy.utils import get_prox
76
from pysindy.utils import get_regularization
87
from pysindy.utils import reorder_constraints
9-
from pysindy.utils import validate_control_variables
108

119

1210
def test_reorder_constraints_1D():

0 commit comments

Comments
 (0)