Skip to content

Commit 95dfb9e

Browse files
Weighted Weak PDE
1 parent 6821283 commit 95dfb9e

File tree

4 files changed

+210
-13
lines changed

4 files changed

+210
-13
lines changed

examples/12_weakform_SINDy_examples.ipynb

Lines changed: 9 additions & 9 deletions
Large diffs are not rendered by default.

pysindy/_core.py

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,12 @@ def fit(
383383
self.feature_names = feature_names
384384

385385
if sample_weight is not None:
386-
sample_weight = _expand_sample_weights(sample_weight, x)
386+
# Choose appropriate expansion depending on the library type
387+
lib = self.feature_library.__class__.__name__
388+
if lib in ("WeakPDELibrary", "WeightedWeakPDELibrary"):
389+
sample_weight = _expand_weak_sample_weights(sample_weight, x, self.feature_library)
390+
else:
391+
sample_weight = _expand_sample_weights(sample_weight, x)
387392

388393
steps = [
389394
("features", self.feature_library),
@@ -978,9 +983,8 @@ def _assert_sample_weights(sample_weight, trajectories):
978983
for sw, traj in zip(sample_weight, trajectories):
979984
a = np.asarray(sw)
980985
if a.ndim == 0:
981-
raise ValueError(
982-
"Each element of sample_weight must be array-like with length equal to the trajectory time dimension"
983-
)
986+
validated.append(a)
987+
continue
984988
if a.shape[0] != traj.n_time:
985989
raise ValueError(
986990
f"sample_weight entry length ({a.shape[0]}) does not match trajectory length ({traj.n_time})"
@@ -1042,3 +1046,38 @@ def _expand_sample_weights(sample_weight, trajectories):
10421046
promoted.append(a)
10431047
return np.concatenate(promoted, axis=0)
10441048

1049+
1050+
def _expand_weak_sample_weights(sample_weight, trajectories, feature_library):
1051+
"""Expand sample weights for weak-form (integral) SINDy libraries.
1052+
1053+
Each trajectory contributes multiple weak test functions (integrals).
1054+
This expands the sample weights to match the number of weak test functions
1055+
per trajectory, and concatenates across all trajectories.
1056+
1057+
Returns
1058+
-------
1059+
np.ndarray
1060+
Expanded weights with shape matching the number of weak test function
1061+
evaluations across all trajectories.
1062+
"""
1063+
sw_list = _assert_sample_weights(sample_weight, trajectories)
1064+
if sw_list is None:
1065+
return None
1066+
1067+
# Number of test functions in the weak library
1068+
n_test_funcs = getattr(feature_library, "K", None)
1069+
if n_test_funcs is None:
1070+
warnings.warn(
1071+
"Weak-form feature library did not define `n_test_functions`; "
1072+
"assuming 1 weight per trajectory."
1073+
)
1074+
n_test_funcs = 1
1075+
1076+
expanded = []
1077+
for sw, traj in zip(sw_list, trajectories):
1078+
# Each trajectory contributes n_test_funcs weak equations
1079+
sw = np.asarray(sw)
1080+
# Expand weights by repeating for each weak test function
1081+
sw_expanded = np.repeat(sw, n_test_funcs, axis=0)
1082+
expanded.append(sw_expanded)
1083+
return np.concatenate(expanded, axis=0)

pysindy/feature_library/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .polynomial_library import PolynomialLibrary
1111
from .sindy_pi_library import SINDyPILibrary
1212
from .weak_pde_library import WeakPDELibrary
13+
from .weighted_weak_pde_library import WeightedWeakPDELibrary
1314

1415
__all__ = [
1516
"ConcatLibrary",
@@ -21,6 +22,7 @@
2122
"PolynomialLibrary",
2223
"PDELibrary",
2324
"WeakPDELibrary",
25+
"WeightedWeakPDELibrary",
2426
"SINDyPILibrary",
2527
"ParameterizedLibrary",
2628
"base",
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import numpy as np
2+
from .weak_pde_library import WeakPDELibrary
3+
from ..utils import AxesArray
4+
5+
6+
class WeightedWeakPDELibrary(WeakPDELibrary):
7+
"""
8+
WeakPDELibrary with GLS whitening via a Cholesky factor built from the
9+
variance field on the spatiotemporal grid.
10+
11+
Parameters
12+
----------
13+
spatiotemporal_weights : ndarray, shape = spatiotemporal_grid.shape[:-1]
14+
Pointwise noise variances σ^2 on the grid (no feature axis).
15+
The covariance of the weak RHS is Cov[V] = M_y diag(σ^2) M_y^T.
16+
17+
Notes
18+
-----
19+
The whitener W = L^{-1}, with L L^T = Cov[V], is left-applied to both Θ and V.
20+
This implements min_x || W(Θ x - V) ||_2^2, i.e., GLS in the weak space.
21+
"""
22+
23+
def __init__(self, *args, spatiotemporal_weights=None, **kwargs):
24+
self.spatiotemporal_weights = spatiotemporal_weights
25+
self._L_chol = None # lower-triangular Cholesky factor of Cov[V]
26+
super().__init__(*args, **kwargs)
27+
28+
# ------------------------------ core whitening ------------------------------
29+
30+
def _build_whitener_from_variance(self):
31+
"""
32+
Construct L such that Cov[V] = L L^T with
33+
Cov[V]_{kℓ} = sum_{g ∈ grid} ( w_k[g] * w_ℓ[g] * σ^2[g] ),
34+
where w_k are the time-derivative weak weights on domain k.
35+
"""
36+
if self.spatiotemporal_weights is None:
37+
self._L_chol = None
38+
return
39+
40+
# --- robust weight-field shape handling ---
41+
base_grid = np.asarray(self.spatiotemporal_grid)
42+
expected = tuple(base_grid.shape[:-1]) # e.g. (Nx, Nt) for a 2D grid
43+
var_grid = np.asarray(self.spatiotemporal_weights)
44+
45+
if var_grid.shape == expected + (1,):
46+
var_grid = var_grid[..., 0]
47+
elif var_grid.shape != expected:
48+
raise ValueError(
49+
f"spatiotemporal_weights must have shape {expected} or {expected + (1,)}, "
50+
f"got {var_grid.shape}"
51+
)
52+
53+
# Flattened variance for convenient indexing
54+
var_flat = var_grid.ravel(order="C")
55+
grid_shape = expected
56+
K = self.K
57+
58+
idx_lists = []
59+
val_lists = []
60+
for k in range(K):
61+
# local multi-index grids (can be 1D, 2D, 3D… arrays)
62+
inds_axes = [np.asarray(ax, dtype=np.intp) for ax in self.inds_k[k]]
63+
grids = np.meshgrid(*inds_axes, indexing="ij")
64+
65+
# linearize to 1D!
66+
lin_idx = np.ravel_multi_index(tuple(grids), dims=grid_shape, order="C")
67+
lin_idx = lin_idx.ravel(order="C")
68+
69+
# corresponding weak RHS weights, flattened to 1D
70+
wk = np.asarray(self.fulltweights[k], dtype=float).ravel(order="C")
71+
72+
# ensure same length (paranoia check)
73+
if wk.shape[0] != lin_idx.shape[0]:
74+
raise RuntimeError(
75+
f"Weight/variance size mismatch on cell {k}: "
76+
f"wk has {wk.shape[0]} entries, indices have {lin_idx.shape[0]}"
77+
)
78+
79+
vals = wk * np.sqrt(var_flat[lin_idx])
80+
81+
idx_lists.append(lin_idx)
82+
val_lists.append(vals)
83+
84+
# Build Cov[V] = B B^T with B_{k,i} = w_k[i] * sqrt(var[i])
85+
Cov = np.zeros((K, K), dtype=float)
86+
for k in range(K):
87+
vk = val_lists[k]
88+
Cov[k, k] = np.dot(vk, vk)
89+
# off-diagonals via set intersection of supports
90+
idx_k = idx_lists[k]
91+
# Use a dict for fast overlap accumulation
92+
map_k = dict(zip(idx_k.tolist(), vk.tolist()))
93+
for ell in range(k + 1, K):
94+
s = 0.0
95+
idx_e = idx_lists[ell]
96+
v_e = val_lists[ell]
97+
map_e = dict(zip(idx_e.tolist(), v_e.tolist()))
98+
# iterate the smaller map
99+
if len(map_k) <= len(map_e):
100+
for j, vkj in map_k.items():
101+
ve = map_e.get(j)
102+
if ve is not None:
103+
s += vkj * ve
104+
else:
105+
for j, ve in map_e.items():
106+
vk_j = map_k.get(j)
107+
if vk_j is not None:
108+
s += vk_j * ve
109+
Cov[k, ell] = s
110+
Cov[ell, k] = s
111+
112+
# diagonal nugget for stability
113+
avg_diag = np.trace(Cov) / max(K, 1)
114+
nugget = 1e-12 * avg_diag
115+
Cov.flat[:: K + 1] += nugget
116+
117+
# robust Cholesky with fallback if needed
118+
try:
119+
self._L_chol = np.linalg.cholesky(Cov)
120+
except np.linalg.LinAlgError:
121+
# inflate nugget and retry once
122+
Cov.flat[:: K + 1] += max(1e-10, 1e-6 * avg_diag)
123+
self._L_chol = np.linalg.cholesky(Cov)
124+
125+
def _apply_whitener(self, A):
126+
"""Return L^{-1} A without forming L^{-1} explicitly."""
127+
if self._L_chol is None:
128+
return A
129+
# solve L X = A → X = L^{-1} A
130+
return np.linalg.solve(self._L_chol, A)
131+
132+
# ------------------------------ hooks ------------------------------
133+
134+
def _weak_form_setup(self):
135+
# parent builds inds_k and the weak weight tensors
136+
super()._weak_form_setup()
137+
# then build the GLS whitener from the variance field
138+
if self.spatiotemporal_weights is not None:
139+
self._build_whitener_from_variance()
140+
141+
def convert_u_dot_integral(self, u):
142+
Vy = super().convert_u_dot_integral(u) # (K, 1)
143+
Vy_w = self._apply_whitener(np.asarray(Vy))
144+
return AxesArray(Vy_w, {"ax_sample": 0, "ax_coord": 1})
145+
146+
def transform(self, x_full):
147+
VTheta_list = super().transform(x_full) # list of (K, n_features)
148+
if self._L_chol is None:
149+
return VTheta_list
150+
out = []
151+
for VTheta in VTheta_list:
152+
A = np.asarray(VTheta)
153+
A_w = self._apply_whitener(A) # (K, m)
154+
out.append(AxesArray(A_w, {"ax_sample": 0, "ax_coord": 1}))
155+
return out
156+

0 commit comments

Comments
 (0)