Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 164 additions & 3 deletions pysindy/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
from scipy.integrate import odeint
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
from sklearn import set_config
from sklearn.base import BaseEstimator
from sklearn.metrics import r2_score
from sklearn.pipeline import Pipeline
from sklearn.utils.validation import check_is_fitted

set_config(enable_metadata_routing=True)
from typing_extensions import Self

from .differentiation import BaseDifferentiation
Expand Down Expand Up @@ -293,6 +296,9 @@ def __init__(
differentiation_method = FiniteDifference(axis=-2)
self.differentiation_method = differentiation_method
self.discrete_time = discrete_time
self.set_fit_request(sample_weight=True)
self.set_score_request(sample_weight=True)
self.optimizer.set_fit_request(sample_weight=True)

def fit(
self,
Expand All @@ -301,6 +307,7 @@ def fit(
x_dot=None,
u=None,
feature_names: Optional[list[str]] = None,
sample_weight=None,
):
"""
Fit a SINDy model.
Expand Down Expand Up @@ -342,6 +349,11 @@ def fit(
Names for the input features (e.g. :code:`['x', 'y', 'z']`).
If None, will use :code:`['x0', 'x1', ...]`.

sample_weight : float or array-like of shape (n_samples,), optional
Per-sample weights for the regression. Passed internally to
the optimizer (e.g. STLSQ). Supports compatibility with
scikit-learn tools such as GridSearchCV when using weighted data.

Returns
-------
self: a fitted :class:`SINDy` instance
Expand Down Expand Up @@ -370,14 +382,24 @@ def fit(

self.feature_names = feature_names

if sample_weight is not None:
mode = (
"weak"
if "Weak" in self.feature_library.__class__.__name__
else "standard"
)
sample_weight = _expand_sample_weights(
sample_weight, x, feature_library=self.feature_library, mode=mode
)

steps = [
("features", self.feature_library),
("shaping", SampleConcatter()),
("model", self.optimizer),
]
x_dot = concat_sample_axis(x_dot)
self.model = Pipeline(steps)
self.model.fit(x, x_dot)
self.model.fit(x, x_dot, sample_weight=sample_weight)
self._fit_shape()

return self
Expand Down Expand Up @@ -411,6 +433,7 @@ def predict(self, x, u=None):
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)

check_is_fitted(self, "model")

if self.n_control_features_ > 0 and u is None:
raise TypeError("Model was fit using control variables, so u is required")
if self.n_control_features_ == 0 and u is not None:
Expand Down Expand Up @@ -466,7 +489,16 @@ def print(self, lhs=None, precision=3, **kwargs):
names = f"{lhs[i]}"
print(f"{names} = {eqn}", **kwargs)

def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
def score(
self,
x,
t,
x_dot=None,
u=None,
metric=r2_score,
sample_weight=None,
**metric_kws,
):
"""
Returns a score for the time derivative prediction produced by the model.

Expand Down Expand Up @@ -500,9 +532,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
<https://scikit-learn.org/stable/modules/model_evaluation.html>`_
for more options.

sample_weight : array-like of shape (n_samples,), optional
Per-sample weights passed directly to the metric. This is the
preferred way to supply weights.

metric_kws: dict, optional
Optional keyword arguments to pass to the metric function.


Returns
-------
score: float
Expand All @@ -522,10 +559,21 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):

x, x_dot = self._process_trajectories(x, t, x_dot)

if sample_weight is not None:
sample_weight = _expand_sample_weights(sample_weight, x)

x_dot = concat_sample_axis(x_dot)
x_dot_predict = concat_sample_axis(x_dot_predict)

x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
if sample_weight is not None:
x_dot, x_dot_predict, good_idx = drop_nan_samples(
x_dot, x_dot_predict, return_indices=True
)
sample_weight = sample_weight[good_idx]
metric_kws = {**metric_kws, "sample_weight": sample_weight}
else:
x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)

return metric(x_dot, x_dot_predict, **metric_kws)

def _process_trajectories(self, x, t, x_dot):
Expand Down Expand Up @@ -909,3 +957,116 @@ def comprehend_and_validate(arr, t):
)
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
return x, x_dot, u


def _expand_sample_weights(
sample_weight, trajectories, feature_library=None, mode="standard"
):
"""
Expand per-trajectory or per-sample weights for use in SINDy estimators.

Parameters
----------
sample_weight : sequence of scalars or array-like
Weights for each trajectory. In "standard" mode, each entry can be:
- a scalar weight (applied to all samples in that trajectory), or
- an array of length equal to the number of samples (n_time) for that
trajectory.
In "weak" mode, each entry must be a single scalar weight per trajectory.

trajectories : sequence
Sequence of trajectory-like objects, each having attributes `n_time` and
`n_coord`.

feature_library : object, optional
Library instance used in weak-form mode. Must define attribute `K`
(the number of weak test functions). If missing, assumes K=1 with a warning.

mode : {'standard', 'weak'}, default='standard'
- "standard": Expand per-sample weights to match concatenated samples.
- "weak": Repeat each trajectory’s single scalar weight `K` times.

Returns
-------
np.ndarray or None
A 1D numpy array of concatenated and expanded sample weights,
or None if `sample_weight` is None.
"""
# -------------------------------------------------------------
# Early exit for None
# -------------------------------------------------------------
if sample_weight is None:
return None

if not (
isinstance(sample_weight, Sequence)
and not isinstance(sample_weight, np.ndarray)
):
raise ValueError(
"sample_weight must be a list or tuple, not a scalar or numpy array."
)

if len(sample_weight) != len(trajectories):
raise ValueError("sample_weight length must match number of trajectories.")

# -------------------------------------------------------------
# Weak mode: one weight per trajectory, repeated K times
# -------------------------------------------------------------
if mode == "weak":
if feature_library is None:
raise ValueError("feature_library is required in weak mode.")

K = getattr(feature_library, "K", None)
if K is None:
warnings.warn("feature_library missing 'K'; assuming K=1.", UserWarning)
K = 1

validated = []
for w, traj in zip(sample_weight, trajectories):
arr = np.asarray(w)
if arr.ndim > 0 and arr.size > 1:
raise ValueError(
"Weak mode expects exactly one weight per trajectory (scalar), "
f"but got shape {arr.shape} for trajectory with {traj.n_time}"
f"samples."
)
validated.append(float(arr))
return np.repeat(validated, K)

# -------------------------------------------------------------
# Standard mode: expand scalars or per-sample arrays
# -------------------------------------------------------------
expanded = []
for w, traj in zip(sample_weight, trajectories):
arr = np.asarray(w)

# Scalar → expand to all samples in trajectory
if arr.ndim == 0:
arr = np.full(traj.n_time, arr, dtype=float)

# 1D array → must match number of samples
elif arr.ndim == 1:
if arr.shape[0] != traj.n_time:
raise ValueError(
f"sample_weight length {arr.shape[0]} does"
f" not match trajectory length {traj.n_time}."
)

# 2D array → only (n,1) allowed
elif arr.ndim == 2:
if arr.shape[1] != 1:
raise ValueError(
"sample_weight 2D arrays must have second dimension = 1."
)
if arr.shape[0] != traj.n_time:
raise ValueError(
"sample_weight 2D array length does not match trajectory length."
)
arr = arr.ravel()

else:
raise ValueError("Invalid sample_weight shape.")

expanded.append(arr.ravel())

return np.concatenate(expanded)
2 changes: 2 additions & 0 deletions pysindy/feature_library/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .polynomial_library import PolynomialLibrary
from .sindy_pi_library import SINDyPILibrary
from .weak_pde_library import WeakPDELibrary
from .weighted_weak_pde_library import WeightedWeakPDELibrary

__all__ = [
"ConcatLibrary",
Expand All @@ -21,6 +22,7 @@
"PolynomialLibrary",
"PDELibrary",
"WeakPDELibrary",
"WeightedWeakPDELibrary",
"SINDyPILibrary",
"ParameterizedLibrary",
"base",
Expand Down
Loading
Loading