@@ -620,11 +620,7 @@ def score(
620620
621621 def _process_multiple_trajectories (self , x , t , x_dot ):
622622 """
623- Handle input data that contains multiple trajectories by doing the
624- necessary validation, reshaping, and computation of derivatives.
625-
626- This method essentially just loops over elements of each list in parallel,
627- validates them, and (optionally) concatenates them together.
623+ Calculate derivatives of input data, iterating through trajectories.
628624
629625 Parameters
630626 ----------
@@ -633,19 +629,16 @@ def _process_multiple_trajectories(self, x, t, x_dot):
633629 trajectory.
634630
635631 t: list of np.ndarray or int
636- List of time points for different trajectories.
637- If a list of ints is passed, each entry is assumed to be the timestep
638- for the corresponding trajectory in x.
632+ List of time points for different trajectories. If a list of ints
633+ is passed, each entry is assumed to be the timestep for the
634+ corresponding trajectory in x. If np.ndarray is passed, it is
635+ used for each trajectory.
639636
640637 x_dot: list of np.ndarray
641638 List of derivative measurements, with each entry corresponding to a
642639 different trajectory. If None, the derivatives will be approximated
643640 from x.
644641
645- return_array: boolean, optional (default True)
646- Whether to return concatenated np.ndarrays.
647- If False, the outputs will be lists with an entry for each trajectory.
648-
649642 Returns
650643 -------
651644 x_out: np.ndarray or list
@@ -658,51 +651,17 @@ def _process_multiple_trajectories(self, x, t, x_dot):
658651 will be an np.ndarray of concatenated trajectories.
659652 If False, x_out will be a list.
660653 """
661- if not isinstance (x , Sequence ):
662- raise TypeError ("Input x must be a list" )
663-
664- if self .discrete_time :
665- x = [validate_input (xi ) for xi in x ]
666- if x_dot is None :
654+ if x_dot is None :
655+ if self .discrete_time :
667656 x_dot = [xi [1 :] for xi in x ]
668657 x = [xi [:- 1 ] for xi in x ]
669658 else :
670- if not isinstance (x_dot , Sequence ):
671- raise TypeError (
672- "x_dot must be a list if used with x of list type "
673- "(i.e. for multiple trajectories)"
674- )
675- x_dot = [validate_input (xd ) for xd in x_dot ]
676- else :
677- if x_dot is None :
678- x = [
679- self .feature_library .validate_input (xi , ti )
680- for xi , ti in _zip_like_sequence (x , t )
681- ]
682659 x_dot = [
683660 self .feature_library .calc_trajectory (
684661 self .differentiation_method , xi , ti
685662 )
686663 for xi , ti in _zip_like_sequence (x , t )
687664 ]
688- else :
689- if not isinstance (x_dot , Sequence ):
690- raise TypeError (
691- "x_dot must be a list if used with x of list type "
692- "(i.e. for multiple trajectories)"
693- )
694- if isinstance (t , Sequence ):
695- x = [
696- self .feature_library .validate_input (xi , ti )
697- for xi , ti in zip (x , t )
698- ]
699- x_dot = [
700- self .feature_library .validate_input (xd , ti )
701- for xd , ti in zip (x_dot , t )
702- ]
703- else :
704- x = [self .feature_library .validate_input (xi , t ) for xi in x ]
705- x_dot = [self .feature_library .validate_input (xd , t ) for xd in x_dot ]
706665 return x , x_dot
707666
708667 def differentiate (self , x , t = None , multiple_trajectories = False ):
0 commit comments