@@ -942,42 +942,28 @@ def comprehend_and_validate(arr, t):
942942 u = [comprehend_and_validate (ui , ti ) for ui , ti in _zip_like_sequence (u , t )]
943943 return x , x_dot , u
944944
945- def _expand_sample_weights (sample_weight , trajectories ):
946- """Expand trajectory-level weights to per-sample weights.
947945
948- Parameters
949- ----------
950- sample_weight : array-like of shape (n_trajectories,) or (n_samples,), default=None
951- If length == n_trajectories, each trajectory weight is expanded to cover
952- all samples in that trajectory.
953- If length == n_samples, interpreted as per-sample weights directly.
954- If None, uniform weighting is applied.
955-
956- trajectories : list of array-like
957- The list of input trajectories, each shape (n_samples_i, n_features).
958-
959- Returns
960- -------
961- sample_weight : ndarray of shape (sum_i n_samples_i,)
962- Per-sample weights, ready to use in metrics.
963- """
946+ def _expand_sample_weights (sample_weight , trajectories ):
964947 if sample_weight is None :
965948 return None
966949
967- sample_weight = np .asarray (sample_weight )
968-
969- n_traj = len (trajectories )
970- n_samples_total = sum (len (traj ) for traj in trajectories )
971-
972- if sample_weight .ndim == 1 and len (sample_weight ) == n_traj :
973- # Efficient expansion using np.repeat
974- traj_lengths = [len (traj ) for traj in trajectories ]
975- return np .repeat (sample_weight , traj_lengths )
950+ # Case: list of arrays, one per trajectory
951+ if isinstance (sample_weight , (list , tuple )):
952+ if len (sample_weight ) != len (trajectories ):
953+ raise ValueError (
954+ f"Expected { len (trajectories )} weight blocks, got { len (sample_weight )} "
955+ )
956+ return np .concatenate ([np .asarray (w ) for w in sample_weight ])
976957
977- if sample_weight .ndim == 1 and len (sample_weight ) == n_samples_total :
978- return sample_weight
958+ # Case: already concatenated 1D array
959+ w = np .asarray (sample_weight )
960+ total = sum (len (traj ) for traj in trajectories )
961+ if w .ndim == 1 and w .shape [0 ] == total :
962+ return w
979963
980964 raise ValueError (
981- f"sample_weight must be length { n_traj } (per trajectory) or "
982- f"{ n_samples_total } (per sample), got { len ( sample_weight ) } "
965+ f"sample_weight must be list of arrays or shape ( { total } ,), "
966+ f"got { w . shape } "
983967 )
968+
969+
0 commit comments