Skip to content

Commit 5fb48c8

Browse files
Adjust weights to account for the more general case
1 parent 1a2a5ec commit 5fb48c8

File tree

1 file changed

+27
-34
lines changed

1 file changed

+27
-34
lines changed

pysindy/_core.py

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -944,63 +944,56 @@ def comprehend_and_validate(arr, t):
944944
return x, x_dot, u
945945

946946
def _expand_sample_weights(sample_weight, trajectories):
947-
"""Expand trajectory-level weights to per-sample or per-component weights.
947+
"""Expand trajectory-level weights to per-sample or per-component weights."""
948948

949-
Parameters
950-
----------
951-
sample_weight : array-like
952-
Can be
953-
- None
954-
- scalar
955-
- shape (n_traj,) : one weight per trajectory
956-
- shape (n_samples_total,) : one weight per sample
957-
- shape (n_samples_total, n_tgt) : per-sample, per-target weights
958-
- shape (n_traj, n_tgt) : one weight per trajectory, per target
959-
960-
trajectories : list of arrays
961-
Each trajectory, shape (n_samples_i, n_features).
962-
963-
Returns
964-
-------
965-
ndarray
966-
Expanded weights:
967-
- (n_samples_total,)
968-
- (n_samples_total, n_tgt)
969-
"""
970949
if sample_weight is None:
971950
return None
972951

973952
sample_weight = np.asarray(sample_weight)
974953
n_traj = len(trajectories)
975954
n_samples_total = sum(len(traj) for traj in trajectories)
976955

977-
# case: one weight per trajectory
978-
if sample_weight.ndim == 1 and len(sample_weight) == n_traj:
956+
# (1) trajectory-level
957+
if sample_weight.ndim == 1 and sample_weight.shape[0] == n_traj:
979958
expanded = []
980959
for w, traj in zip(sample_weight, trajectories):
981960
expanded.extend([w] * len(traj))
982961
return np.asarray(expanded)
983962

984-
# case: one weight per sample
985-
if sample_weight.ndim == 1 and len(sample_weight) == n_samples_total:
963+
# (2) sample-level
964+
if sample_weight.ndim == 1 and sample_weight.shape[0] == n_samples_total:
986965
return sample_weight
987966

988-
# case: per-sample, per-target
967+
# (3) per-sample, per-target
989968
if sample_weight.ndim == 2 and sample_weight.shape[0] == n_samples_total:
990969
return sample_weight
991970

992-
# case: per-trajectory, per-target
971+
# (4) per-trajectory, per-target
993972
if sample_weight.ndim == 2 and sample_weight.shape[0] == n_traj:
994973
expanded = []
995974
for w_vec, traj in zip(sample_weight, trajectories):
975+
expanded.append(np.tile(w_vec, (len(traj), 1)))
976+
return np.vstack(expanded)
977+
978+
# (5) per-trajectory, per-time-step, per-target
979+
if sample_weight.ndim == 3 and sample_weight.shape[0] == n_traj:
980+
expanded = []
981+
for w_block, traj in zip(sample_weight, trajectories):
996982
n = len(traj)
997-
expanded.append(np.tile(w_vec, (n, 1))) # repeat per sample in that traj
983+
if w_block.shape[0] != n:
984+
raise ValueError(
985+
f"sample_weight time dimension {w_block.shape[0]} "
986+
f"does not match trajectory length {n}"
987+
)
988+
expanded.append(w_block) # shape (n, n_targets)
998989
return np.vstack(expanded)
999990

1000991
raise ValueError(
1001-
f"sample_weight must be length {n_traj} (per trajectory), "
1002-
f"{n_samples_total} (per sample), "
1003-
f"({n_samples_total}, n_targets) (per sample/target), or "
1004-
f"({n_traj}, n_targets) (per trajectory/target). "
1005-
f"Got {sample_weight.shape}"
992+
f"sample_weight must be one of:\n"
993+
f" (n_traj,), (n_samples_total,), "
994+
f" (n_samples_total, n_targets), "
995+
f" (n_traj, n_targets), "
996+
f" (n_traj, n_time, n_targets).\n"
997+
f"Got {sample_weight.shape}."
1006998
)
999+

0 commit comments

Comments
 (0)