Skip to content

Commit 2e983a4

Browse files
Adjusted expand_weights
1 parent 97eae34 commit 2e983a4

File tree

1 file changed

+54
-18
lines changed

1 file changed

+54
-18
lines changed

pysindy/_core.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -943,28 +943,64 @@ def comprehend_and_validate(arr, t):
943943
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
944944
return x, x_dot, u
945945

946-
947946
def _expand_sample_weights(sample_weight, trajectories):
947+
"""Expand trajectory-level weights to per-sample or per-component weights.
948+
949+
Parameters
950+
----------
951+
sample_weight : array-like
952+
Can be
953+
- None
954+
- scalar
955+
- shape (n_traj,) : one weight per trajectory
956+
- shape (n_samples_total,) : one weight per sample
957+
- shape (n_samples_total, n_tgt) : per-sample, per-target weights
958+
- shape (n_traj, n_tgt) : one weight per trajectory, per target
959+
960+
trajectories : list of arrays
961+
Each trajectory, shape (n_samples_i, n_features).
962+
963+
Returns
964+
-------
965+
ndarray
966+
Expanded weights:
967+
- (n_samples_total,)
968+
- (n_samples_total, n_tgt)
969+
"""
948970
if sample_weight is None:
949971
return None
950972

951-
# Case: list of arrays, one per trajectory
952-
if isinstance(sample_weight, (list, tuple)):
953-
if len(sample_weight) != len(trajectories):
954-
raise ValueError(
955-
f"Expected {len(trajectories)} weight blocks, got {len(sample_weight)}"
956-
)
957-
return np.concatenate([np.asarray(w) for w in sample_weight])
958-
959-
# Case: already concatenated 1D array
960-
w = np.asarray(sample_weight)
961-
total = sum(len(traj) for traj in trajectories)
962-
if w.ndim == 1 and w.shape[0] == total:
963-
return w
973+
sample_weight = np.asarray(sample_weight)
974+
n_traj = len(trajectories)
975+
n_samples_total = sum(len(traj) for traj in trajectories)
976+
977+
# case: one weight per trajectory
978+
if sample_weight.ndim == 1 and len(sample_weight) == n_traj:
979+
expanded = []
980+
for w, traj in zip(sample_weight, trajectories):
981+
expanded.extend([w] * len(traj))
982+
return np.asarray(expanded)
983+
984+
# case: one weight per sample
985+
if sample_weight.ndim == 1 and len(sample_weight) == n_samples_total:
986+
return sample_weight
987+
988+
# case: per-sample, per-target
989+
if sample_weight.ndim == 2 and sample_weight.shape[0] == n_samples_total:
990+
return sample_weight
991+
992+
# case: per-trajectory, per-target
993+
if sample_weight.ndim == 2 and sample_weight.shape[0] == n_traj:
994+
expanded = []
995+
for w_vec, traj in zip(sample_weight, trajectories):
996+
n = len(traj)
997+
expanded.append(np.tile(w_vec, (n, 1))) # repeat per sample in that traj
998+
return np.vstack(expanded)
964999

9651000
raise ValueError(
966-
f"sample_weight must be list of arrays or shape ({total},), "
967-
f"got {w.shape}"
1001+
f"sample_weight must be length {n_traj} (per trajectory), "
1002+
f"{n_samples_total} (per sample), "
1003+
f"({n_samples_total}, n_targets) (per sample/target), or "
1004+
f"({n_traj}, n_targets) (per trajectory/target). "
1005+
f"Got {sample_weight.shape}"
9681006
)
969-
970-

0 commit comments

Comments
 (0)