Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pysindy/feature_library/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax
import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils.validation import check_is_fitted

Expand All @@ -20,14 +21,17 @@
from ..utils import wrap_axes


class BaseFeatureLibrary(TransformerMixin):
class BaseFeatureLibrary(TransformerMixin, BaseEstimator):
"""
Base class for feature libraries.

Forces subclasses to implement ``fit``, ``transform``,
and ``get_feature_names`` methods.
"""

n_features_in_: int
n_output_features_: int

def validate_input(self, x, *args, **kwargs):
return validate_no_reshape(x, *args, **kwargs)

Expand Down Expand Up @@ -109,7 +113,7 @@ def transform(self, x):

# Force subclasses to implement this
@abc.abstractmethod
def get_feature_names(self, input_features=None):
def get_feature_names(self, input_features=None) -> list[str]:
"""Return feature names for output features.

Parameters
Expand Down
2 changes: 1 addition & 1 deletion pysindy/feature_library/polynomial_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .base import x_sequence_or_item


class PolynomialLibrary(PolynomialFeatures, BaseFeatureLibrary):
class PolynomialLibrary(BaseFeatureLibrary, PolynomialFeatures):
"""Generate polynomial and interaction features.

This is the same as :code:`sklearn.preprocessing.PolynomialFeatures`,
Expand Down
39 changes: 22 additions & 17 deletions pysindy/feature_library/weak_pde_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@
self.num_trajectories = 1
self.differentiation_method = differentiation_method
self.diff_kwargs = diff_kwargs
if function_library is None:
self.multiindices = multiindices
self.spatiotemporal_grid = spatiotemporal_grid
if self.function_library is None:
self.function_library = PolynomialLibrary(degree=3, include_bias=False)

if spatiotemporal_grid is None:
Expand All @@ -177,14 +179,22 @@
"in favor of differetiation_method and diff_kwargs.",
UserWarning,
)
# Weak form checks and setup
self._weak_form_setup()

def set_params(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
self._weak_form_setup()

Check warning on line 188 in pysindy/feature_library/weak_pde_library.py

View check run for this annotation

Codecov / codecov/patch

pysindy/feature_library/weak_pde_library.py#L186-L188

Added lines #L186 - L188 were not covered by tests

def _weak_form_setup(self):
# list of integrals
indices = ()
if np.array(spatiotemporal_grid).ndim == 1:
spatiotemporal_grid = np.reshape(
spatiotemporal_grid, (len(spatiotemporal_grid), 1)
if np.array(self.spatiotemporal_grid).ndim == 1:
self.spatiotemporal_grid = np.reshape(
self.spatiotemporal_grid, (len(self.spatiotemporal_grid), 1)
)
dims = spatiotemporal_grid.shape[:-1]
dims = self.spatiotemporal_grid.shape[:-1]
self.grid_dims = dims
self.grid_ndim = len(dims)

Expand All @@ -195,30 +205,25 @@
self.ind_range = len(dims) - 1

for i in range(self.ind_range):
indices = indices + (range(derivative_order + 1),)
indices = indices + (range(self.derivative_order + 1),)

if multiindices is None:
if self.multiindices is None:
multiindices = []
for ind in iproduct(*indices):
current = np.array(ind)
if np.sum(ind) > 0 and np.sum(ind) <= derivative_order:
if np.sum(ind) > 0 and np.sum(ind) <= self.derivative_order:
multiindices.append(current)
multiindices = np.array(multiindices)
num_derivatives = len(multiindices)
self.multiindices = np.array(multiindices)
num_derivatives = len(self.multiindices)
if num_derivatives > 0:
self.derivative_order = np.max(multiindices)
self.derivative_order = np.max(self.multiindices)

self.num_derivatives = num_derivatives
self.multiindices = multiindices

self.spatiotemporal_grid = AxesArray(
spatiotemporal_grid, axes=comprehend_axes(spatiotemporal_grid)
self.spatiotemporal_grid, axes=comprehend_axes(self.spatiotemporal_grid)
)

# Weak form checks and setup
self._weak_form_setup()

def _weak_form_setup(self):
xt1, xt2 = self._get_spatial_endpoints()
L_xt = xt2 - xt1
if self.H_xt is not None:
Expand Down
59 changes: 40 additions & 19 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from abc import ABC
from abc import abstractmethod
from itertools import product
from typing import Collection
from typing import Optional
from typing import Sequence
from typing import TypeVar
from typing import Union

import numpy as np
Expand Down Expand Up @@ -42,19 +42,36 @@
from .utils import validate_no_reshape


TrajectoryType = TypeVar("TrajectoryType", list[np.ndarray], np.ndarray)


class _BaseSINDy(BaseEstimator, ABC):

feature_library: BaseFeatureLibrary
optimizer: _BaseOptimizer
discrete_time: bool
model: Pipeline
feature_names: Optional[list[str]]
# Hacks to remove later
feature_names: Optional[list[str]]
discrete_time: bool = False
n_control_features_: int = 0

@abstractmethod
def fit(self, x, t, *args, **kwargs) -> Self:
def fit(self, x: TrajectoryType, t: TrajectoryType, *args, **kwargs) -> Self:
...

Check warning on line 61 in pysindy/pysindy.py

View check run for this annotation

Codecov / codecov/patch

pysindy/pysindy.py#L61

Added line #L61 was not covered by tests

@abstractmethod
def predict(self, x: np.ndarray) -> np.ndarray:
...

Check warning on line 65 in pysindy/pysindy.py

View check run for this annotation

Codecov / codecov/patch

pysindy/pysindy.py#L65

Added line #L65 was not covered by tests

@abstractmethod
def simulate(self, x0: np.ndarray, t: np.ndarray) -> np.ndarray:
...

Check warning on line 69 in pysindy/pysindy.py

View check run for this annotation

Codecov / codecov/patch

pysindy/pysindy.py#L69

Added line #L69 was not covered by tests

@abstractmethod
def score(
self, x: TrajectoryType, t: TrajectoryType, x_dot: TrajectoryType
) -> float:
...

def _fit_shape(self):
Expand All @@ -69,6 +86,19 @@
feature_names.append("u" + str(i))
self.feature_names = feature_names

def coefficients(self):
"""
Get an array of the coefficients learned by SINDy model.

Returns
-------
coef: np.ndarray, shape (n_input_features, n_output_features)
Learned coefficients of the SINDy model.
Equivalent to :math:`\\Xi^\\top` in the literature.
"""
check_is_fitted(self)
return self.optimizer.coef_

def equations(self, precision: int = 3) -> list[str]:
"""
Get the right hand sides of the SINDy model equations.
Expand Down Expand Up @@ -128,7 +158,7 @@
lhs = f"({name})'"
print(f"{lhs} = {eqn}", **kwargs)

def get_feature_names(self):
def get_feature_names(self) -> list[str]:
"""
Get a list of names of features used by SINDy model.

Expand Down Expand Up @@ -609,19 +639,6 @@
return result[0]
return result

def coefficients(self):
"""
Get an array of the coefficients learned by SINDy model.

Returns
-------
coef: np.ndarray, shape (n_input_features, n_output_features)
Learned coefficients of the SINDy model.
Equivalent to :math:`\\Xi^\\top` in the literature.
"""
check_is_fitted(self, "model")
return self.optimizer.coef_

def simulate(
self,
x0,
Expand Down Expand Up @@ -792,7 +809,10 @@


def _zip_like_sequence(x, t):
"""Create an iterable like zip(x, t), but works if t is scalar."""
"""Create an iterable like zip(x, t), but works if t is scalar.

If t is an array, it is repeated for each x
"""
if isinstance(t, Sequence):
return zip(x, t)
else:
Expand Down Expand Up @@ -870,7 +890,8 @@
Tuple of updated x, t, x_dot, u
"""
x = [x]
if isinstance(t, Collection):
# if t is not a dt
if not isinstance(t, np.ScalarType):
t = [t]
if x_dot is not None:
x_dot = [x_dot]
Expand Down
Loading