@@ -294,6 +294,8 @@ def __init__(
294294 differentiation_method = FiniteDifference (axis = - 2 )
295295 self .differentiation_method = differentiation_method
296296 self .discrete_time = discrete_time
297+ self .set_fit_request (sample_weight = True )
298+ self .set_score_request (sample_weight = True )
297299
298300 def fit (
299301 self ,
@@ -302,6 +304,7 @@ def fit(
302304 x_dot = None ,
303305 u = None ,
304306 feature_names : Optional [list [str ]] = None ,
307+ sample_weight = None ,
305308 ):
306309 """
307310 Fit a SINDy model.
@@ -342,6 +345,11 @@ def fit(
342345 feature_names : list of string, length n_input_features, optional
343346 Names for the input features (e.g. :code:`['x', 'y', 'z']`).
344347 If None, will use :code:`['x0', 'x1', ...]`.
348+
349+ sample_weight : float or array-like of shape (n_samples,), optional
350+ Per-sample weights for the regression. Passed internally to
351+ the optimizer (e.g. STLSQ). Supports compatibility with
352+ scikit-learn tools such as GridSearchCV when using weighted data.
345353
346354 Returns
347355 -------
@@ -371,14 +379,18 @@ def fit(
371379
372380 self .feature_names = feature_names
373381
382+ # User may give one weight per trajectory or one weight per sample
383+ if sample_weight is not None :
384+ sample_weight = _expand_sample_weights (sample_weight , x )
385+
374386 steps = [
375387 ("features" , self .feature_library ),
376388 ("shaping" , SampleConcatter ()),
377389 ("model" , self .optimizer ),
378390 ]
379391 x_dot = concat_sample_axis (x_dot )
380392 self .model = Pipeline (steps )
381- self .model .fit (x , x_dot )
393+ self .model .fit (x , x_dot , model__sample_weight = sample_weight )
382394 self ._fit_shape ()
383395
384396 return self
@@ -412,6 +424,7 @@ def predict(self, x, u=None):
412424 x , _ , u = _comprehend_and_validate_inputs (x , 1 , None , u , self .feature_library )
413425
414426 check_is_fitted (self , "model" )
427+
415428 if self .n_control_features_ > 0 and u is None :
416429 raise TypeError ("Model was fit using control variables, so u is required" )
417430 if self .n_control_features_ == 0 and u is not None :
@@ -467,7 +480,7 @@ def print(self, lhs=None, precision=3, **kwargs):
467480 names = f"{ lhs [i ]} "
468481 print (f"{ names } = { eqn } " , ** kwargs )
469482
470- def score (self , x , t , x_dot = None , u = None , metric = r2_score , ** metric_kws ):
483+ def score (self , x , t , x_dot = None , u = None , metric = r2_score , sample_weight = None , ** metric_kws ):
471484 """
472485 Returns a score for the time derivative prediction produced by the model.
473486
@@ -500,9 +513,14 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
500513 See `Scikit-learn \
501514 <https://scikit-learn.org/stable/modules/model_evaluation.html>`_
502515 for more options.
516+
517+ sample_weight : array-like of shape (n_samples,), optional
518+ Per-sample weights passed directly to the metric. This is the
519+ preferred way to supply weights.
503520
504521 metric_kws: dict, optional
505522 Optional keyword arguments to pass to the metric function.
523+
506524
507525 Returns
508526 -------
@@ -523,10 +541,21 @@ def score(self, x, t, x_dot=None, u=None, metric=r2_score, **metric_kws):
523541
524542 x , x_dot = self ._process_trajectories (x , t , x_dot )
525543
544+ if sample_weight is not None :
545+ sample_weight = _expand_sample_weights (sample_weight , x )
546+
526547 x_dot = concat_sample_axis (x_dot )
527548 x_dot_predict = concat_sample_axis (x_dot_predict )
528549
529- x_dot , x_dot_predict = drop_nan_samples (x_dot , x_dot_predict )
550+ if sample_weight is not None :
551+ x_dot , x_dot_predict , good_idx = drop_nan_samples (
552+ x_dot , x_dot_predict , return_indices = True
553+ )
554+ sample_weight = sample_weight [good_idx ]
555+ metric_kws = {** metric_kws , "sample_weight" : sample_weight }
556+ else :
557+ x_dot , x_dot_predict = drop_nan_samples (x_dot , x_dot_predict )
558+
530559 return metric (x_dot , x_dot_predict , ** metric_kws )
531560
532561 def _process_trajectories (self , x , t , x_dot ):
@@ -910,3 +939,43 @@ def comprehend_and_validate(arr, t):
910939 )
911940 u = [comprehend_and_validate (ui , ti ) for ui , ti in _zip_like_sequence (u , t )]
912941 return x , x_dot , u
942+
943+ def _expand_sample_weights (sample_weight , trajectories ):
944+ """Expand trajectory-level weights to per-sample weights.
945+
946+ Parameters
947+ ----------
948+ sample_weight : array-like of shape (n_trajectories,) or (n_samples,), default=None
949+ If length == n_trajectories, each trajectory weight is expanded to cover
950+ all samples in that trajectory.
951+ If length == n_samples, interpreted as per-sample weights directly.
952+ If None, uniform weighting is applied.
953+
954+ trajectories : list of array-like
955+ The list of input trajectories, each shape (n_samples_i, n_features).
956+
957+ Returns
958+ -------
959+ sample_weight : ndarray of shape (sum_i n_samples_i,)
960+ Per-sample weights, ready to use in metrics.
961+ """
962+ if sample_weight is None :
963+ return None
964+
965+ sample_weight = np .asarray (sample_weight )
966+
967+ n_traj = len (trajectories )
968+ n_samples_total = sum (len (traj ) for traj in trajectories )
969+
970+ if sample_weight .ndim == 1 and len (sample_weight ) == n_traj :
971+ # Efficient expansion using np.repeat
972+ traj_lengths = [len (traj ) for traj in trajectories ]
973+ return np .repeat (sample_weight , traj_lengths )
974+
975+ if sample_weight .ndim == 1 and len (sample_weight ) == n_samples_total :
976+ return sample_weight
977+
978+ raise ValueError (
979+ f"sample_weight must be length { n_traj } (per trajectory) or "
980+ f"{ n_samples_total } (per sample), got { len (sample_weight )} "
981+ )
0 commit comments