Skip to content

Commit 2b6a2eb

Browse files
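Adjusted expand weights: merged `_assert_sample_weights` and `_expand_weak_sample_weights` into a single `_expand_sample_weights` with a `mode` argument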
1 parent 95dfb9e commit 2b6a2eb

File tree

1 file changed

+61 -122 lines changed

1 file changed

+61 -122 lines changed

pysindy/_core.py

Lines changed: 61 additions & 122 deletions
Original file line number | Diff line number | Diff line change
@@ -383,12 +383,10 @@ def fit(
383383
self.feature_names = feature_names
384384

385385
if sample_weight is not None:
386-
# Choose appropriate expansion depending on the library type
387-
lib = self.feature_library.__class__.__name__
388-
if lib in ("WeakPDELibrary", "WeightedWeakPDELibrary"):
389-
sample_weight = _expand_weak_sample_weights(sample_weight, x, self.feature_library)
390-
else:
391-
sample_weight = _expand_sample_weights(sample_weight, x)
386+
mode = "weak" if "Weak" in self.feature_library.__class__.__name__ else "standard"
387+
sample_weight = _expand_sample_weights(
388+
sample_weight, x, feature_library=self.feature_library, mode=mode
389+
)
392390

393391
steps = [
394392
("features", self.feature_library),
@@ -947,137 +945,78 @@ def comprehend_and_validate(arr, t):
947945
u = [comprehend_and_validate(ui, ti) for ui, ti in _zip_like_sequence(u, t)]
948946
return x, x_dot, u
949947

950-
def _assert_sample_weights(sample_weight, trajectories):
951-
"""Validate per-trajectory sample_weight input.
948+
def _expand_sample_weights(sample_weight, trajectories, feature_library=None, mode="standard"):
949+
"""Expand per-trajectory sample weights for estimators or weak-form libraries.
952950
953-
Requirements enforced:
954-
- `sample_weight` must be a Python sequence (e.g., list or tuple), not a scalar
955-
and not a single numpy array.
956-
- Length of the sequence must equal the number of trajectories.
957-
- Each element must be array-like with first axis length equal to the
958-
corresponding trajectory's `n_time`.
959-
- If an element is 2D, its second dimension must be 1 (broadcast) or equal
960-
to the trajectory coordinate count `n_coord`.
951+
Parameters
952+
----------
953+
sample_weight : sequence of array-like or None
954+
Per-trajectory sample weights. Each element corresponds to one trajectory.
955+
trajectories : sequence
956+
Sequence of trajectory objects, each with attributes `n_time` and `n_coord`.
957+
feature_library : object, optional
958+
Library instance, required when mode='weak'.
959+
mode : {'standard', 'weak'}, default='standard'
960+
Expansion mode:
961+
- 'standard' : Concatenate weights per sample or per coordinate.
962+
- 'weak' : Expand weights for weak-form (integral) test functions.
961963
962964
Returns
963965
-------
964-
validated_list : list of numpy arrays
965-
The per-trajectory arrays (not concatenated).
966+
np.ndarray or None
967+
Concatenated and expanded sample weights, or None if no weights are given.
966968
"""
967-
968969
if sample_weight is None:
969970
return None
970971

971972
if not (isinstance(sample_weight, Sequence) and not isinstance(sample_weight, np.ndarray)):
972-
raise ValueError(
973-
"sample_weight must be a sequence (e.g. list or tuple) with one entry per trajectory. "
974-
"Do not pass a scalar or a single numpy array here."
975-
)
973+
raise ValueError("sample_weight must be a list or tuple, not a scalar or numpy array.")
976974

977975
if len(sample_weight) != len(trajectories):
978-
raise ValueError(
979-
f"When passing a sequence of sample_weight, its length ({len(sample_weight)}) must equal the number of trajectories ({len(trajectories)})"
980-
)
976+
raise ValueError("sample_weight length must match number of trajectories.")
981977

978+
# --- Validate shape consistency ---
982979
validated = []
983980
for sw, traj in zip(sample_weight, trajectories):
984-
a = np.asarray(sw)
985-
if a.ndim == 0:
986-
validated.append(a)
981+
arr = np.asarray(sw)
982+
if arr.ndim == 0:
983+
validated.append(arr)
987984
continue
988-
if a.shape[0] != traj.n_time:
989-
raise ValueError(
990-
f"sample_weight entry length ({a.shape[0]}) does not match trajectory length ({traj.n_time})"
991-
)
992-
if a.ndim == 2 and a.shape[1] not in (1, traj.n_coord):
993-
raise ValueError(
994-
f"sample_weight entry second dimension ({a.shape[1]}) must be 1 or equal to the number of coordinates ({traj.n_coord})"
995-
)
996-
validated.append(a)
997-
return validated
998-
999-
1000-
def _expand_sample_weights(sample_weight, trajectories):
1001-
"""Concatenate per-trajectory sample weights into final array for estimators.
1002-
1003-
Expects `sample_weight` to be a sequence (validated by _assert_sample_weights).
1004-
Returns either:
1005-
- 1D array of shape (N_total_samples,) when weights are per-sample scalars, or
1006-
- 2D array of shape (N_total_samples, n_coord) when per-sample-per-coordinate weights
1007-
are provided (or when 1D weights are promoted to match n_coord).
1008-
"""
1009-
1010-
sw_list = _assert_sample_weights(sample_weight, trajectories)
1011-
1012-
if sw_list is None:
1013-
return None
1014-
1015-
# Determine common coordinate count
985+
if arr.shape[0] != traj.n_time:
986+
raise ValueError("sample_weight entry length does not match trajectory length.")
987+
if arr.ndim == 2 and arr.shape[1] not in (1, traj.n_coord):
988+
raise ValueError("sample_weight 2D second dim must be 1 or equal to n_coord.")
989+
validated.append(arr)
990+
991+
# --- Weak-form expansion ---
992+
if mode == "weak":
993+
n_funcs = getattr(feature_library, "K", 1)
994+
if n_funcs is None:
995+
warnings.warn("feature_library missing 'K'; assuming 1 test function.")
996+
n_funcs = 1
997+
return np.concatenate([np.repeat(np.asarray(sw), n_funcs, axis=0) for sw in validated])
998+
999+
# --- Standard expansion ---
10161000
n_coords = {int(t.n_coord) for t in trajectories}
10171001
if len(n_coords) != 1:
1018-
raise ValueError("All trajectories must have the same number of coordinate components")
1019-
n_coord = next(iter(n_coords))
1020-
1021-
arrs = []
1022-
dims = []
1023-
for a in sw_list:
1024-
a = np.asarray(a)
1025-
if a.ndim == 1:
1026-
arrs.append(a)
1027-
dims.append(1)
1028-
else: # a.ndim == 2
1029-
if a.shape[1] == 1:
1030-
arrs.append(a[:, 0])
1031-
dims.append(1)
1032-
else:
1033-
arrs.append(a)
1034-
dims.append(n_coord)
1035-
1036-
# All 1D -> concatenate to 1D
1037-
if all(d == 1 for d in dims):
1038-
return np.concatenate([a.reshape(-1) for a in arrs], axis=0)
1039-
1040-
# Otherwise promote 1D arrays to 2D and concatenate
1041-
promoted = []
1042-
for a, d in zip(arrs, dims):
1043-
if d == 1:
1044-
promoted.append(np.broadcast_to(a.reshape(-1, 1), (a.shape[0], n_coord)))
1045-
else:
1046-
promoted.append(a)
1047-
return np.concatenate(promoted, axis=0)
1048-
1049-
1050-
def _expand_weak_sample_weights(sample_weight, trajectories, feature_library):
1051-
"""Expand sample weights for weak-form (integral) SINDy libraries.
1052-
1053-
Each trajectory contributes multiple weak test functions (integrals).
1054-
This expands the sample weights to match the number of weak test functions
1055-
per trajectory, and concatenates across all trajectories.
1056-
1057-
Returns
1058-
-------
1059-
np.ndarray
1060-
Expanded weights with shape matching the number of weak test function
1061-
evaluations across all trajectories.
1062-
"""
1063-
sw_list = _assert_sample_weights(sample_weight, trajectories)
1064-
if sw_list is None:
1065-
return None
1066-
1067-
# Number of test functions in the weak library
1068-
n_test_funcs = getattr(feature_library, "K", None)
1069-
if n_test_funcs is None:
1070-
warnings.warn(
1071-
"Weak-form feature library did not define `n_test_functions`; "
1072-
"assuming 1 weight per trajectory."
1073-
)
1074-
n_test_funcs = 1
1075-
1076-
expanded = []
1077-
for sw, traj in zip(sw_list, trajectories):
1078-
# Each trajectory contributes n_test_funcs weak equations
1079-
sw = np.asarray(sw)
1080-
# Expand weights by repeating for each weak test function
1081-
sw_expanded = np.repeat(sw, n_test_funcs, axis=0)
1082-
expanded.append(sw_expanded)
1002+
raise ValueError("All trajectories must have the same n_coord.")
1003+
n_coord = n_coords.pop()
1004+
1005+
processed = []
1006+
for arr in validated:
1007+
arr = np.asarray(arr)
1008+
if arr.ndim == 1:
1009+
arr = arr.reshape(-1, 1)
1010+
elif arr.ndim == 2 and arr.shape[1] == 1:
1011+
pass # already correct shape
1012+
processed.append(arr)
1013+
1014+
# Promote to n_coord if any arrays have multiple coordinates
1015+
is_scalar_weight = all(a.shape[1] == 1 for a in processed)
1016+
if is_scalar_weight:
1017+
return np.concatenate([a.ravel() for a in processed])
1018+
expanded = [
1019+
np.broadcast_to(a, (a.shape[0], n_coord)) if a.shape[1] == 1 else a
1020+
for a in processed
1021+
]
10831022
return np.concatenate(expanded, axis=0)

0 commit comments

Comments
 (0)