@@ -943,28 +943,64 @@ def comprehend_and_validate(arr, t):
943943 u = [comprehend_and_validate (ui , ti ) for ui , ti in _zip_like_sequence (u , t )]
944944 return x , x_dot , u
945945
946-
947946def _expand_sample_weights (sample_weight , trajectories ):
947+ """Expand trajectory-level weights to per-sample or per-component weights.
948+
949+ Parameters
950+ ----------
951+ sample_weight : array-like
952+ Can be
953+ - None
954+ - scalar
955+ - shape (n_traj,) : one weight per trajectory
956+ - shape (n_samples_total,) : one weight per sample
957+ - shape (n_samples_total, n_tgt) : per-sample, per-target weights
958+ - shape (n_traj, n_tgt) : one weight per trajectory, per target
959+
960+ trajectories : list of arrays
961+ Each trajectory, shape (n_samples_i, n_features).
962+
963+ Returns
964+ -------
965+ ndarray
966+ Expanded weights:
967+ - (n_samples_total,)
968+ - (n_samples_total, n_tgt)
969+ """
948970 if sample_weight is None :
949971 return None
950972
951- # Case: list of arrays, one per trajectory
952- if isinstance (sample_weight , (list , tuple )):
953- if len (sample_weight ) != len (trajectories ):
954- raise ValueError (
955- f"Expected { len (trajectories )} weight blocks, got { len (sample_weight )} "
956- )
957- return np .concatenate ([np .asarray (w ) for w in sample_weight ])
958-
959- # Case: already concatenated 1D array
960- w = np .asarray (sample_weight )
961- total = sum (len (traj ) for traj in trajectories )
962- if w .ndim == 1 and w .shape [0 ] == total :
963- return w
973+ sample_weight = np .asarray (sample_weight )
974+ n_traj = len (trajectories )
975+ n_samples_total = sum (len (traj ) for traj in trajectories )
976+
977+ # case: one weight per trajectory
978+ if sample_weight .ndim == 1 and len (sample_weight ) == n_traj :
979+ expanded = []
980+ for w , traj in zip (sample_weight , trajectories ):
981+ expanded .extend ([w ] * len (traj ))
982+ return np .asarray (expanded )
983+
984+ # case: one weight per sample
985+ if sample_weight .ndim == 1 and len (sample_weight ) == n_samples_total :
986+ return sample_weight
987+
988+ # case: per-sample, per-target
989+ if sample_weight .ndim == 2 and sample_weight .shape [0 ] == n_samples_total :
990+ return sample_weight
991+
992+ # case: per-trajectory, per-target
993+ if sample_weight .ndim == 2 and sample_weight .shape [0 ] == n_traj :
994+ expanded = []
995+ for w_vec , traj in zip (sample_weight , trajectories ):
996+ n = len (traj )
997+ expanded .append (np .tile (w_vec , (n , 1 ))) # repeat per sample in that traj
998+ return np .vstack (expanded )
964999
9651000 raise ValueError (
966- f"sample_weight must be list of arrays or shape ({ total } ,), "
967- f"got { w .shape } "
1001+ f"sample_weight must be length { n_traj } (per trajectory), "
1002+ f"{ n_samples_total } (per sample), "
1003+ f"({ n_samples_total } , n_targets) (per sample/target), or "
1004+ f"({ n_traj } , n_targets) (per trajectory/target). "
1005+ f"Got { sample_weight .shape } "
9681006 )
969-
970-
0 commit comments