Skip to content

Commit 59d24ba

Browse files
authored
Merge pull request #225 from dynamicslab/cln_process_multiple
Clean _process_multiple_trajectories
2 parents 632585a + 7e02c31 commit 59d24ba

File tree

2 files changed

+7
-63
lines changed

2 files changed

+7
-63
lines changed

pysindy/pysindy.py

Lines changed: 7 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -620,11 +620,7 @@ def score(
620620

621621
def _process_multiple_trajectories(self, x, t, x_dot):
622622
"""
623-
Handle input data that contains multiple trajectories by doing the
624-
necessary validation, reshaping, and computation of derivatives.
625-
626-
This method essentially just loops over elements of each list in parallel,
627-
validates them, and (optionally) concatenates them together.
623+
Calculate derivatives of input data, iterating through trajectories.
628624
629625
Parameters
630626
----------
@@ -633,19 +629,16 @@ def _process_multiple_trajectories(self, x, t, x_dot):
633629
trajectory.
634630
635631
t: list of np.ndarray or int
636-
List of time points for different trajectories.
637-
If a list of ints is passed, each entry is assumed to be the timestep
638-
for the corresponding trajectory in x.
632+
List of time points for different trajectories. If a list of ints
633+
is passed, each entry is assumed to be the timestep for the
634+
corresponding trajectory in x. If np.ndarray is passed, it is
635+
used for each trajectory.
639636
640637
x_dot: list of np.ndarray
641638
List of derivative measurements, with each entry corresponding to a
642639
different trajectory. If None, the derivatives will be approximated
643640
from x.
644641
645-
return_array: boolean, optional (default True)
646-
Whether to return concatenated np.ndarrays.
647-
If False, the outputs will be lists with an entry for each trajectory.
648-
649642
Returns
650643
-------
651644
x_out: np.ndarray or list
@@ -658,51 +651,17 @@ def _process_multiple_trajectories(self, x, t, x_dot):
658651
will be an np.ndarray of concatenated trajectories.
659652
If False, x_out will be a list.
660653
"""
661-
if not isinstance(x, Sequence):
662-
raise TypeError("Input x must be a list")
663-
664-
if self.discrete_time:
665-
x = [validate_input(xi) for xi in x]
666-
if x_dot is None:
654+
if x_dot is None:
655+
if self.discrete_time:
667656
x_dot = [xi[1:] for xi in x]
668657
x = [xi[:-1] for xi in x]
669658
else:
670-
if not isinstance(x_dot, Sequence):
671-
raise TypeError(
672-
"x_dot must be a list if used with x of list type "
673-
"(i.e. for multiple trajectories)"
674-
)
675-
x_dot = [validate_input(xd) for xd in x_dot]
676-
else:
677-
if x_dot is None:
678-
x = [
679-
self.feature_library.validate_input(xi, ti)
680-
for xi, ti in _zip_like_sequence(x, t)
681-
]
682659
x_dot = [
683660
self.feature_library.calc_trajectory(
684661
self.differentiation_method, xi, ti
685662
)
686663
for xi, ti in _zip_like_sequence(x, t)
687664
]
688-
else:
689-
if not isinstance(x_dot, Sequence):
690-
raise TypeError(
691-
"x_dot must be a list if used with x of list type "
692-
"(i.e. for multiple trajectories)"
693-
)
694-
if isinstance(t, Sequence):
695-
x = [
696-
self.feature_library.validate_input(xi, ti)
697-
for xi, ti in zip(x, t)
698-
]
699-
x_dot = [
700-
self.feature_library.validate_input(xd, ti)
701-
for xd, ti in zip(x_dot, t)
702-
]
703-
else:
704-
x = [self.feature_library.validate_input(xi, t) for xi in x]
705-
x_dot = [self.feature_library.validate_input(xd, t) for xd in x_dot]
706665
return x, x_dot
707666

708667
def differentiate(self, x, t=None, multiple_trajectories=False):

test/test_pysindy.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -580,21 +580,6 @@ def test_complexity(data_lorenz):
580580
assert model.complexity < 10
581581

582582

583-
def test_multiple_trajectories_errors(data_multiple_trajctories, data_discrete_time):
584-
x, t = data_multiple_trajctories
585-
586-
model = SINDy()
587-
with pytest.raises(TypeError):
588-
model._process_multiple_trajectories(np.array(x, dtype=object), t, x)
589-
with pytest.raises(TypeError):
590-
model._process_multiple_trajectories(x, t, np.array(x, dtype=object))
591-
592-
x = data_discrete_time
593-
model = SINDy(discrete_time=True)
594-
with pytest.raises(TypeError):
595-
model._process_multiple_trajectories(x, t, np.array(x, dtype=object))
596-
597-
598583
def test_simulate_errors(data_lorenz):
599584
x, t = data_lorenz
600585
model = SINDy()

0 commit comments

Comments
 (0)