77from sklearn .utils .validation import check_is_fitted
88
99from ..utils import AxesArray
10+ from ..utils import comprehend_axes
1011from .base import BaseFeatureLibrary
1112from .base import x_sequence_or_item
1213from pysindy .differentiation import FiniteDifference
@@ -197,10 +198,6 @@ def __init__(
197198 spatiotemporal_grid = np .reshape (
198199 spatiotemporal_grid , (len (spatiotemporal_grid ), 1 )
199200 )
200- self .grid_dims = spatiotemporal_grid .shape [:- 1 ]
201- self .spatial_grid_dims = self .spatial_grid .shape [:- 1 ]
202- self .grid_ndim = len (spatiotemporal_grid .shape [:- 1 ])
203- self .spatial_grid_ndim = len (self .spatial_grid_dims )
204201
205202 # if want to include temporal terms -> range(len(dims))
206203 if self .implicit_terms :
@@ -222,7 +219,9 @@ def __init__(
222219
223220 self .num_derivatives = num_derivatives
224221 self .multiindices = multiindices
225- self .spatiotemporal_grid = spatiotemporal_grid
222+ self .spatiotemporal_grid = AxesArray (
223+ spatiotemporal_grid , comprehend_axes (spatiotemporal_grid )
224+ )
226225
227226 @staticmethod
228227 def _combinations (n_features , n_args , interaction_only ):
@@ -276,7 +275,13 @@ def get_feature_names(self, input_features=None):
276275 def derivative_string (multiindex ):
277276 ret = ""
278277 for axis in range (self .ind_range ):
279- if (axis == self .grid_ndim - 1 ) and self .implicit_terms :
278+ if self .implicit_terms and (
279+ axis
280+ in [
281+ self .spatiotemporal_grid .ax_time ,
282+ self .spatiotemporal_grid .ax_sample ,
283+ ]
284+ ):
280285 str_deriv = "t"
281286 else :
282287 str_deriv = str (axis + 1 )
@@ -392,11 +397,11 @@ def transform(self, x_full):
392397 library_derivatives = np .empty (shape , dtype = x .dtype )
393398 library_idx = 0
394399 for multiindex in self .multiindices :
395- derivs = x . copy ()
400+ derivs = x
396401 for axis in range (self .ind_range ):
397402 if multiindex [axis ] > 0 :
398403 s = [0 for dim in self .spatiotemporal_grid .shape ]
399- s [axis ] = slice (self .grid_dims [axis ])
404+ s [axis ] = slice (self .spatiotemporal_grid . shape [axis ])
400405 s [- 1 ] = axis
401406
402407 derivs = FiniteDifference (
@@ -459,7 +464,8 @@ def transform(self, x_full):
459464 shape ,
460465 )
461466 library_idx += n_library_terms * self .num_derivatives * n_features
462- xp_full = xp_full + [AxesArray (xp , self .comprehend_axes (xp ))]
467+ xp = AxesArray (xp , comprehend_axes (xp ))
468+ xp_full .append (xp )
463469 if self .library_ensemble :
464470 xp_full = self ._ensemble (xp_full )
465471 return xp_full
0 commit comments