Skip to content

Commit 9084812

Browse files
Adjusted set_fit_reqeuste
1 parent 9accf8b commit 9084812

File tree

1 file changed

+17
-31
lines changed

1 file changed

+17
-31
lines changed

pysindy/_core.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -942,42 +942,28 @@ def comprehend_and_validate(arr, t):
942942
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
943943
return x, x_dot, u
944944

945-
def _expand_sample_weights(sample_weight, trajectories):
946-
"""Expand trajectory-level weights to per-sample weights.
947945

948-
Parameters
949-
----------
950-
sample_weight : array-like of shape (n_trajectories,) or (n_samples,), default=None
951-
If length == n_trajectories, each trajectory weight is expanded to cover
952-
all samples in that trajectory.
953-
If length == n_samples, interpreted as per-sample weights directly.
954-
If None, uniform weighting is applied.
955-
956-
trajectories : list of array-like
957-
The list of input trajectories, each shape (n_samples_i, n_features).
958-
959-
Returns
960-
-------
961-
sample_weight : ndarray of shape (sum_i n_samples_i,)
962-
Per-sample weights, ready to use in metrics.
963-
"""
946+
def _expand_sample_weights(sample_weight, trajectories):
964947
if sample_weight is None:
965948
return None
966949

967-
sample_weight = np.asarray(sample_weight)
968-
969-
n_traj = len(trajectories)
970-
n_samples_total = sum(len(traj) for traj in trajectories)
971-
972-
if sample_weight.ndim == 1 and len(sample_weight) == n_traj:
973-
# Efficient expansion using np.repeat
974-
traj_lengths = [len(traj) for traj in trajectories]
975-
return np.repeat(sample_weight, traj_lengths)
950+
# Case: list of arrays, one per trajectory
951+
if isinstance(sample_weight, (list, tuple)):
952+
if len(sample_weight) != len(trajectories):
953+
raise ValueError(
954+
f"Expected {len(trajectories)} weight blocks, got {len(sample_weight)}"
955+
)
956+
return np.concatenate([np.asarray(w) for w in sample_weight])
976957

977-
if sample_weight.ndim == 1 and len(sample_weight) == n_samples_total:
978-
return sample_weight
958+
# Case: already concatenated 1D array
959+
w = np.asarray(sample_weight)
960+
total = sum(len(traj) for traj in trajectories)
961+
if w.ndim == 1 and w.shape[0] == total:
962+
return w
979963

980964
raise ValueError(
981-
f"sample_weight must be length {n_traj} (per trajectory) or "
982-
f"{n_samples_total} (per sample), got {len(sample_weight)}"
965+
f"sample_weight must be list of arrays or shape ({total},), "
966+
f"got {w.shape}"
983967
)
968+
969+

0 commit comments

Comments
 (0)