Skip to content

Commit 60f520c

Browse files
Added sample weight to SINDy fit method.
To do: account for weak-form SINDy.
1 parent dcf84dc commit 60f520c

File tree

3 files changed

+84
-4
lines changed

3 files changed

+84
-4
lines changed

pysindy/optimizers/stlsq.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,15 @@ class STLSQ(BaseOptimizer):
7878
history_ : list
7979
History of ``coef_``. ``history_[k]`` contains the values of
8080
``coef_`` at iteration k of sequentially thresholded least-squares.
81+
82+
83+
Notes
84+
-----
85+
- Supports ``sample_weight`` during :meth:`fit`. Sample weights are applied
86+
by rescaling rows of the regression problem (X, y) before column
87+
normalization and thresholding. This allows weighted least squares
88+
formulations in SINDy.
89+
- When ``sample_weight`` is not provided, all samples are treated equally.
8190
8291
Examples
8392
--------

pysindy/pysindy.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,8 @@ def __init__(
294294
differentiation_method = FiniteDifference(axis=-2)
295295
self.differentiation_method = differentiation_method
296296
self.discrete_time = discrete_time
297+
self.set_fit_request(sample_weight=True)
298+
self.set_score_request(sample_weight=True)
297299

298300
def fit(
299301
self,
@@ -302,6 +304,7 @@ def fit(
302304
x_dot=None,
303305
u=None,
304306
feature_names: Optional[list[str]] = None,
307+
sample_weight=None,
305308
):
306309
"""
307310
Fit a SINDy model.
@@ -342,6 +345,11 @@ def fit(
342345
feature_names : list of string, length n_input_features, optional
343346
Names for the input features (e.g. :code:`['x', 'y', 'z']`).
344347
If None, will use :code:`['x0', 'x1', ...]`.
348+
349+
sample_weight : array-like of shape (n_trajectories,) or (n_samples,), optional
350+
Per-trajectory or per-sample weights for the regression. Passed internally to
351+
the optimizer (e.g. STLSQ). Supports compatibility with
352+
scikit-learn tools such as GridSearchCV when using weighted data.
345353
346354
Returns
347355
-------
@@ -371,14 +379,18 @@ def fit(
371379

372380
self.feature_names = feature_names
373381

382+
# User may give one weight per trajectory or one weight per sample
383+
if sample_weight is not None:
384+
sample_weight = _expand_sample_weights(sample_weight, x)
385+
374386
steps = [
375387
("features", self.feature_library),
376388
("shaping", SampleConcatter()),
377389
("model", self.optimizer),
378390
]
379391
x_dot = concat_sample_axis(x_dot)
380392
self.model = Pipeline(steps)
381-
self.model.fit(x, x_dot)
393+
self.model.fit(x, x_dot, model__sample_weight=sample_weight)
382394
self._fit_shape()
383395

384396
return self
@@ -412,6 +424,7 @@ def predict(self, x, u=None):
412424
x, _, u = _comprehend_and_validate_inputs(x, 1, None, u, self.feature_library)
413425

414426
check_is_fitted(self, "model")
427+
415428
if self.n_control_features_ > 0 and u is None:
416429
raise TypeError("Model was fit using control variables, so u is required")
417430
if self.n_control_features_ == 0 and u is not None:
@@ -467,7 +480,7 @@ def print(self, lhs=None, precision=3, **kwargs):
467480
names = f"{lhs[i]}"
468481
print(f"{names} = {eqn}", **kwargs)
469482

470-
def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
483+
def score(self, x, t, x_dot=None, u=None, metric=r2_score, sample_weight=None, **metric_kws):
471484
"""
472485
Returns a score for the time derivative prediction produced by the model.
473486
@@ -500,9 +513,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
500513
See `Scikit-learn \
501514
<https://scikit-learn.org/stable/modules/model_evaluation.html>`_
502515
for more options.
516+
517+
sample_weight : array-like of shape (n_trajectories,) or (n_samples,), optional
518+
Per-sample weights passed directly to the metric. This is the
519+
preferred way to supply weights.
503520
504521
metric_kws: dict, optional
505522
Optional keyword arguments to pass to the metric function.
523+
506524
507525
Returns
508526
-------
@@ -523,10 +541,21 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
523541

524542
x, x_dot = self._process_trajectories(x, t, x_dot)
525543

544+
if sample_weight is not None:
545+
sample_weight = _expand_sample_weights(sample_weight, x)
546+
526547
x_dot = concat_sample_axis(x_dot)
527548
x_dot_predict = concat_sample_axis(x_dot_predict)
528549

529-
x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
550+
if sample_weight is not None:
551+
x_dot, x_dot_predict, good_idx = drop_nan_samples(
552+
x_dot, x_dot_predict, return_indices=True
553+
)
554+
sample_weight = sample_weight[good_idx]
555+
metric_kws = {**metric_kws, "sample_weight": sample_weight}
556+
else:
557+
x_dot, x_dot_predict = drop_nan_samples(x_dot, x_dot_predict)
558+
530559
return metric(x_dot, x_dot_predict, **metric_kws)
531560

532561
def _process_trajectories(self, x, t, x_dot):
@@ -910,3 +939,43 @@ def comprehend_and_validate(arr, t):
910939
)
911940
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
912941
return x, x_dot, u
942+
943+
def _expand_sample_weights(sample_weight, trajectories):
944+
"""Expand trajectory-level weights to per-sample weights.
945+
946+
Parameters
947+
----------
948+
sample_weight : array-like of shape (n_trajectories,) or (n_samples,), default=None
949+
If length == n_trajectories, each trajectory weight is expanded to cover
950+
all samples in that trajectory.
951+
If length == n_samples, interpreted as per-sample weights directly.
952+
If None, uniform weighting is applied.
953+
954+
trajectories : list of array-like
955+
The list of input trajectories, each shape (n_samples_i, n_features).
956+
957+
Returns
958+
-------
959+
sample_weight : ndarray of shape (sum_i n_samples_i,)
960+
Per-sample weights, ready to use in metrics.
961+
"""
962+
if sample_weight is None:
963+
return None
964+
965+
sample_weight = np.asarray(sample_weight)
966+
967+
n_traj = len(trajectories)
968+
n_samples_total = sum(len(traj) for traj in trajectories)
969+
970+
if sample_weight.ndim == 1 and len(sample_weight) == n_traj:
971+
# Efficient expansion using np.repeat
972+
traj_lengths = [len(traj) for traj in trajectories]
973+
return np.repeat(sample_weight, traj_lengths)
974+
975+
if sample_weight.ndim == 1 and len(sample_weight) == n_samples_total:
976+
return sample_weight
977+
978+
raise ValueError(
979+
f"sample_weight must be length {n_traj} (per trajectory) or "
980+
f"{n_samples_total} (per sample), got {len(sample_weight)}"
981+
)

pysindy/utils/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _check_control_shape(x, u, trim_last_point):
128128
return u_arr
129129

130130

131-
def drop_nan_samples(x, y):
131+
def drop_nan_samples(x, y, return_indices: bool = False):
132132
"""Drops samples from x and y where either has a nan value"""
133133
x_non_sample_axes = tuple(ax for ax in range(x.ndim) if ax != x.ax_sample)
134134
y_non_sample_axes = tuple(ax for ax in range(y.ndim) if ax != y.ax_sample)
@@ -137,6 +137,8 @@ def drop_nan_samples(x, y):
137137
good_sample_ind = np.nonzero(x_good_samples & y_good_samples)[0]
138138
x = x.take(good_sample_ind, axis=x.ax_sample)
139139
y = y.take(good_sample_ind, axis=y.ax_sample)
140+
if return_indices:
141+
return x, y, good_sample_ind
140142
return x, y
141143

142144

0 commit comments

Comments
 (0)