Skip to content

Commit 7e02c31

Browse files
Merge branch 'master' into cln_process_multiple
2 parents 01b92ff + 632585a commit 7e02c31

File tree

18 files changed

+124
-244
lines changed

18 files changed

+124
-244
lines changed

pysindy/differentiation/sindy_derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def set_params(self, **params):
5151
# Simple optimization to gain speed (inspect is slow)
5252
return self
5353
else:
54-
self.kwargs.update(params)
54+
self.kwargs.update(params["kwargs"])
5555

5656
return self
5757

pysindy/feature_library/base.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
from sklearn.utils.validation import check_is_fitted
1414

1515
from ..utils import AxesArray
16-
from ..utils import DefaultShapedInputsMixin
16+
from ..utils import comprehend_axes
1717
from ..utils import validate_no_reshape
1818
from ..utils import wrap_axes
19-
from ..utils.axes import ax_time_to_ax_sample
2019

2120

22-
class BaseFeatureLibrary(DefaultShapedInputsMixin, TransformerMixin):
21+
class BaseFeatureLibrary(TransformerMixin):
2322
"""
2423
Base class for feature libraries.
2524
@@ -187,17 +186,12 @@ def func(self, x, *args, **kwargs):
187186
return wrapped_func(self, x, *args, **kwargs)
188187
else:
189188
if not sparse.issparse(x):
190-
x = AxesArray(x, self.comprehend_axes(x))
191-
x = ax_time_to_ax_sample(x)
189+
x = AxesArray(x, comprehend_axes(x))
192190
reconstructor = np.array
193191
else: # sparse arrays
194192
reconstructor = type(x)
195-
axes = self.comprehend_axes(x)
196-
wrap_axes(axes)(x)
197-
# Can't use x = ax_time_to_ax_sample(x) b/c that creates
198-
# an AxesArray
199-
x.ax_sample = x.ax_time
200-
x.ax_time = None
193+
axes = comprehend_axes(x)
194+
wrap_axes(axes, x)
201195
result = wrapped_func(self, [x], *args, **kwargs)
202196
if isinstance(result, Sequence): # e.g. transform() returns x
203197
return reconstructor(result[0])
@@ -332,8 +326,8 @@ def transform(self, x_full):
332326
xp[..., start_feature_index:end_feature_index] = lib.transform([x])[0]
333327

334328
current_feat += lib_n_output_features
335-
336-
xp_full = xp_full + [AxesArray(xp, self.comprehend_axes(xp))]
329+
xp = AxesArray(xp, comprehend_axes(xp))
330+
xp_full.append(xp)
337331
if self.library_ensemble:
338332
xp_full = self._ensemble(xp_full)
339333
return xp_full
@@ -565,7 +559,8 @@ def transform(self, x_full):
565559

566560
current_feat += lib_i_n_output_features * lib_j_n_output_features
567561

568-
xp_full = xp_full + [AxesArray(xp, self.comprehend_axes(xp))]
562+
xp = AxesArray(xp, comprehend_axes(xp))
563+
xp_full.append(xp)
569564
if self.library_ensemble:
570565
xp_full = self._ensemble(xp_full)
571566
return xp_full

pysindy/feature_library/custom_library.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn.utils.validation import check_is_fitted
99

1010
from ..utils import AxesArray
11+
from ..utils import comprehend_axes
1112
from .base import BaseFeatureLibrary
1213
from .base import x_sequence_or_item
1314

@@ -229,7 +230,8 @@ def transform(self, x_full):
229230
xp[..., library_idx] = f(*[x[..., j] for j in c])
230231
library_idx += 1
231232

232-
xp_full = xp_full + [AxesArray(xp, self.comprehend_axes(xp))]
233+
xp = AxesArray(xp, comprehend_axes(xp))
234+
xp_full.append(xp)
233235
if self.library_ensemble:
234236
xp_full = self._ensemble(xp_full)
235237
return xp_full

pysindy/feature_library/fourier_library.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sklearn.utils.validation import check_is_fitted
44

55
from ..utils import AxesArray
6+
from ..utils import comprehend_axes
67
from .base import BaseFeatureLibrary
78
from .base import x_sequence_or_item
89

@@ -173,8 +174,8 @@ def transform(self, x_full):
173174
if self.include_cos:
174175
xp[..., idx] = np.cos((i + 1) * x[..., j])
175176
idx += 1
176-
177-
xp_full = xp_full + [AxesArray(xp, self.comprehend_axes(xp))]
177+
xp = AxesArray(xp, comprehend_axes(xp))
178+
xp_full.append(xp)
178179
if self.library_ensemble:
179180
xp_full = self._ensemble(xp_full)
180181
return xp_full

pysindy/feature_library/generalized_library.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Type
2-
31
import numpy as np
42
from sklearn import __version__
53
from sklearn.utils.validation import check_is_fitted
@@ -130,12 +128,10 @@ def __init__(
130128
if weak_libraries:
131129
self.validate_input = libraries[weak_libraries].validate_input
132130
self.calc_trajectory = libraries[weak_libraries].calc_trajectory
133-
self.comprehend_axes = libraries[weak_libraries].comprehend_axes
134131
self.spatiotemporal_grid = libraries[weak_libraries].spatiotemporal_grid
135132
elif pde_libraries:
136133
self.validate_input = libraries[pde_libraries].validate_input
137134
self.calc_trajectory = libraries[pde_libraries].calc_trajectory
138-
self.comprehend_axes = libraries[pde_libraries].comprehend_axes
139135
self.spatial_grid = libraries[pde_libraries].spatial_grid
140136
else:
141137
raise ValueError(
@@ -246,23 +242,6 @@ def fit(self, x_full, y=None):
246242

247243
return self
248244

249-
def has_type(self, libtype: Type, exclusively=False) -> bool:
250-
"""Checks whether this library has a specific library type.
251-
252-
Parameters
253-
----------
254-
libtype : A type of feature library
255-
exclusively: whether to check all libraries
256-
257-
Returns
258-
-------
259-
Bool indicating whether specific library type is present
260-
"""
261-
has_inst = map(lambda lib: isinstance(lib, libtype), self.libraries_)
262-
if exclusively:
263-
return all(has_inst)
264-
return any(has_inst)
265-
266245
@x_sequence_or_item
267246
def transform(self, x_full):
268247
"""Transform data with libs provided below.
@@ -283,13 +262,9 @@ def transform(self, x_full):
283262

284263
xp_full = []
285264
for x in x_full:
286-
# n_samples = x.shape[x.ax_sample]
287265
n_features = x.shape[x.ax_coord]
288266
shape = np.array(x.shape)
289267

290-
# if isinstance(self.libraries_[0], WeakPDELibrary):
291-
# n_samples = self.libraries_[0].K * self.libraries_[0].num_trajectories
292-
293268
if float(__version__[:3]) >= 1.0:
294269
n_input_features = self.n_features_in_
295270
else:

pysindy/feature_library/pde_library.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.utils.validation import check_is_fitted
88

99
from ..utils import AxesArray
10+
from ..utils import comprehend_axes
1011
from .base import BaseFeatureLibrary
1112
from .base import x_sequence_or_item
1213
from pysindy.differentiation import FiniteDifference
@@ -197,10 +198,6 @@ def __init__(
197198
spatiotemporal_grid = np.reshape(
198199
spatiotemporal_grid, (len(spatiotemporal_grid), 1)
199200
)
200-
self.grid_dims = spatiotemporal_grid.shape[:-1]
201-
self.spatial_grid_dims = self.spatial_grid.shape[:-1]
202-
self.grid_ndim = len(spatiotemporal_grid.shape[:-1])
203-
self.spatial_grid_ndim = len(self.spatial_grid_dims)
204201

205202
# if want to include temporal terms -> range(len(dims))
206203
if self.implicit_terms:
@@ -222,7 +219,9 @@ def __init__(
222219

223220
self.num_derivatives = num_derivatives
224221
self.multiindices = multiindices
225-
self.spatiotemporal_grid = spatiotemporal_grid
222+
self.spatiotemporal_grid = AxesArray(
223+
spatiotemporal_grid, comprehend_axes(spatiotemporal_grid)
224+
)
226225

227226
@staticmethod
228227
def _combinations(n_features, n_args, interaction_only):
@@ -276,7 +275,13 @@ def get_feature_names(self, input_features=None):
276275
def derivative_string(multiindex):
277276
ret = ""
278277
for axis in range(self.ind_range):
279-
if (axis == self.grid_ndim - 1) and self.implicit_terms:
278+
if self.implicit_terms and (
279+
axis
280+
in [
281+
self.spatiotemporal_grid.ax_time,
282+
self.spatiotemporal_grid.ax_sample,
283+
]
284+
):
280285
str_deriv = "t"
281286
else:
282287
str_deriv = str(axis + 1)
@@ -392,11 +397,11 @@ def transform(self, x_full):
392397
library_derivatives = np.empty(shape, dtype=x.dtype)
393398
library_idx = 0
394399
for multiindex in self.multiindices:
395-
derivs = x.copy()
400+
derivs = x
396401
for axis in range(self.ind_range):
397402
if multiindex[axis] > 0:
398403
s = [0 for dim in self.spatiotemporal_grid.shape]
399-
s[axis] = slice(self.grid_dims[axis])
404+
s[axis] = slice(self.spatiotemporal_grid.shape[axis])
400405
s[-1] = axis
401406

402407
derivs = FiniteDifference(
@@ -459,7 +464,8 @@ def transform(self, x_full):
459464
shape,
460465
)
461466
library_idx += n_library_terms * self.num_derivatives * n_features
462-
xp_full = xp_full + [AxesArray(xp, self.comprehend_axes(xp))]
467+
xp = AxesArray(xp, comprehend_axes(xp))
468+
xp_full.append(xp)
463469
if self.library_ensemble:
464470
xp_full = self._ensemble(xp_full)
465471
return xp_full

pysindy/feature_library/polynomial_library.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.utils.validation import check_is_fitted
1111

1212
from ..utils import AxesArray
13+
from ..utils import comprehend_axes
1314
from ..utils import wrap_axes
1415
from .base import BaseFeatureLibrary
1516
from .base import x_sequence_or_item
@@ -229,15 +230,11 @@ def transform(self, x_full):
229230
for x in x_full:
230231
if sparse.issparse(x) and x.format not in ["csr", "csc"]:
231232
# create new with correct sparse
232-
axes = self.comprehend_axes(x)
233+
axes = comprehend_axes(x)
233234
x = x.asformat("csr")
234-
wrap_axes(axes)(x)
235-
# Can't use x = ax_time_to_ax_sample(x) b/c that creates
236-
# an AxesArray
237-
x.ax_sample = x.ax_time
238-
x.ax_time = None
235+
wrap_axes(axes, x)
239236

240-
n_samples = x.shape[x.ax_sample]
237+
n_samples = x.shape[x.ax_time]
241238
n_features = x.shape[x.ax_coord]
242239
if float(__version__[:3]) >= 1.0:
243240
if n_features != self.n_features_in_:

pysindy/feature_library/weak_pde_library.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -831,28 +831,6 @@ def transform(self, x_full):
831831
"""
832832
check_is_fitted(self)
833833

834-
# x = check_array(x)
835-
#
836-
# n_samples_original_full, n_features = x.shape
837-
# n_samples_original = n_samples_original_full // self.num_trajectories
838-
#
839-
# if float(__version__[:3]) >= 1.0:
840-
# if n_features != self.n_features_in_:
841-
# raise ValueError("x shape does not match training shape")
842-
# else:
843-
# if n_features != self.n_input_features_:
844-
# raise ValueError("x shape does not match training shape")
845-
#
846-
# if self.spatiotemporal_grid is not None:
847-
# n_samples = self.K
848-
# n_samples_full = self.K * self.num_trajectories
849-
#
850-
# xp_full = np.empty(
851-
# (self.num_trajectories, n_samples, self.n_output_features_), dtype=x.dtype
852-
# )
853-
# x_full = np.reshape(
854-
# x, np.concatenate([[self.num_trajectories], self.grid_dims, [n_features]])
855-
# )
856834
xp_full = []
857835
for x in x_full:
858836
n_features = x.shape[x.ax_coord]
@@ -871,12 +849,6 @@ def transform(self, x_full):
871849
library_functions = np.empty((self.K, n_library_terms), dtype=x.dtype)
872850

873851
# Evaluate the functions on the indices of domain cells
874-
# x_shaped = np.reshape(
875-
# x,
876-
# np.concatenate([self.spatiotemporal_grid.shape[:-1], [x.shape[-1]]]),
877-
# )
878-
# dims = np.array(x_shaped.shape)
879-
# dims[-1] = n_library_terms
880852
funcs = np.zeros((*x.shape[:-1], n_library_terms))
881853
func_idx = 0
882854
for f in self.functions:

pysindy/optimizers/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,5 @@ def _drop_random_samples(
363363
rand_inds = choice(range(n_samples), n_subset, replace=replace)
364364
x_new = np.take(x, rand_inds, axis=x.ax_sample)
365365
x_dot_new = np.take(x_dot, rand_inds, axis=x.ax_sample)
366-
# x_dot_new = np.take(x_dot, rand_inds, axis=x_dot.ax_sample)
367366

368367
return x_new, x_dot_new

0 commit comments

Comments
 (0)