@@ -944,63 +944,56 @@ def comprehend_and_validate(arr, t):
944944 return x , x_dot , u
945945
946946def _expand_sample_weights (sample_weight , trajectories ):
947- """Expand trajectory-level weights to per-sample or per-component weights.
947+ """Expand trajectory-level weights to per-sample or per-component weights."""
948948
949- Parameters
950- ----------
951- sample_weight : array-like
952- Can be
953- - None
954- - scalar
955- - shape (n_traj,) : one weight per trajectory
956- - shape (n_samples_total,) : one weight per sample
957- - shape (n_samples_total, n_tgt) : per-sample, per-target weights
958- - shape (n_traj, n_tgt) : one weight per trajectory, per target
959-
960- trajectories : list of arrays
961- Each trajectory, shape (n_samples_i, n_features).
962-
963- Returns
964- -------
965- ndarray
966- Expanded weights:
967- - (n_samples_total,)
968- - (n_samples_total, n_tgt)
969- """
970949 if sample_weight is None :
971950 return None
972951
973952 sample_weight = np .asarray (sample_weight )
974953 n_traj = len (trajectories )
975954 n_samples_total = sum (len (traj ) for traj in trajectories )
976955
977- # case: one weight per trajectory
978- if sample_weight .ndim == 1 and len ( sample_weight ) == n_traj :
956+ # (1) trajectory-level
957+ if sample_weight .ndim == 1 and sample_weight . shape [ 0 ] == n_traj :
979958 expanded = []
980959 for w , traj in zip (sample_weight , trajectories ):
981960 expanded .extend ([w ] * len (traj ))
982961 return np .asarray (expanded )
983962
984- # case: one weight per sample
985- if sample_weight .ndim == 1 and len ( sample_weight ) == n_samples_total :
963+ # (2) sample-level
964+ if sample_weight .ndim == 1 and sample_weight . shape [ 0 ] == n_samples_total :
986965 return sample_weight
987966
988- # case: per-sample, per-target
967+ # (3) per-sample, per-target
989968 if sample_weight .ndim == 2 and sample_weight .shape [0 ] == n_samples_total :
990969 return sample_weight
991970
992- # case: per-trajectory, per-target
971+ # (4) per-trajectory, per-target
993972 if sample_weight .ndim == 2 and sample_weight .shape [0 ] == n_traj :
994973 expanded = []
995974 for w_vec , traj in zip (sample_weight , trajectories ):
975+ expanded .append (np .tile (w_vec , (len (traj ), 1 )))
976+ return np .vstack (expanded )
977+
978+ # (5) per-trajectory, per-time-step, per-target
979+ if sample_weight .ndim == 3 and sample_weight .shape [0 ] == n_traj :
980+ expanded = []
981+ for w_block , traj in zip (sample_weight , trajectories ):
996982 n = len (traj )
997- expanded .append (np .tile (w_vec , (n , 1 ))) # repeat per sample in that traj
983+ if w_block .shape [0 ] != n :
984+ raise ValueError (
985+ f"sample_weight time dimension { w_block .shape [0 ]} "
986+ f"does not match trajectory length { n } "
987+ )
988+ expanded .append (w_block ) # shape (n, n_targets)
998989 return np .vstack (expanded )
999990
1000991 raise ValueError (
1001- f"sample_weight must be length { n_traj } (per trajectory), "
1002- f"{ n_samples_total } (per sample), "
1003- f"({ n_samples_total } , n_targets) (per sample/target), or "
1004- f"({ n_traj } , n_targets) (per trajectory/target). "
1005- f"Got { sample_weight .shape } "
992+ f"sample_weight must be one of:\n "
993+ f" (n_traj,), (n_samples_total,), "
994+ f" (n_samples_total, n_targets), "
995+ f" (n_traj, n_targets), "
996+ f" (n_traj, n_time, n_targets).\n "
997+ f"Got { sample_weight .shape } ."
1006998 )
999+
0 commit comments