|
| 1 | +import numpy as np |
| 2 | +from .weak_pde_library import WeakPDELibrary |
| 3 | +from ..utils import AxesArray |
| 4 | + |
| 5 | + |
| 6 | +class WeightedWeakPDELibrary(WeakPDELibrary): |
| 7 | + """ |
| 8 | + WeakPDELibrary with GLS whitening via a Cholesky factor built from the |
| 9 | + variance field on the spatiotemporal grid. |
| 10 | +
|
| 11 | + Parameters |
| 12 | + ---------- |
| 13 | + spatiotemporal_weights : ndarray, shape = spatiotemporal_grid.shape[:-1] |
| 14 | + Pointwise noise variances σ^2 on the grid (no feature axis). |
| 15 | + The covariance of the weak RHS is Cov[V] = M_y diag(σ^2) M_y^T. |
| 16 | +
|
| 17 | + Notes |
| 18 | + ----- |
| 19 | + The whitener W = L^{-1}, with L L^T = Cov[V], is left-applied to both Θ and V. |
| 20 | + This implements min_x || W(Θ x - V) ||_2^2, i.e., GLS in the weak space. |
| 21 | + """ |
| 22 | + |
| 23 | + def __init__(self, *args, spatiotemporal_weights=None, **kwargs): |
| 24 | + self.spatiotemporal_weights = spatiotemporal_weights |
| 25 | + self._L_chol = None # lower-triangular Cholesky factor of Cov[V] |
| 26 | + super().__init__(*args, **kwargs) |
| 27 | + |
| 28 | + # ------------------------------ core whitening ------------------------------ |
| 29 | + |
| 30 | + def _build_whitener_from_variance(self): |
| 31 | + """ |
| 32 | + Construct L such that Cov[V] = L L^T with |
| 33 | + Cov[V]_{kℓ} = sum_{g ∈ grid} ( w_k[g] * w_ℓ[g] * σ^2[g] ), |
| 34 | + where w_k are the time-derivative weak weights on domain k. |
| 35 | + """ |
| 36 | + if self.spatiotemporal_weights is None: |
| 37 | + self._L_chol = None |
| 38 | + return |
| 39 | + |
| 40 | + # --- robust weight-field shape handling --- |
| 41 | + base_grid = np.asarray(self.spatiotemporal_grid) |
| 42 | + expected = tuple(base_grid.shape[:-1]) # e.g. (Nx, Nt) for a 2D grid |
| 43 | + var_grid = np.asarray(self.spatiotemporal_weights) |
| 44 | + |
| 45 | + if var_grid.shape == expected + (1,): |
| 46 | + var_grid = var_grid[..., 0] |
| 47 | + elif var_grid.shape != expected: |
| 48 | + raise ValueError( |
| 49 | + f"spatiotemporal_weights must have shape {expected} or {expected + (1,)}, " |
| 50 | + f"got {var_grid.shape}" |
| 51 | + ) |
| 52 | + |
| 53 | + # Flattened variance for convenient indexing |
| 54 | + var_flat = var_grid.ravel(order="C") |
| 55 | + grid_shape = expected |
| 56 | + K = self.K |
| 57 | + |
| 58 | + idx_lists = [] |
| 59 | + val_lists = [] |
| 60 | + for k in range(K): |
| 61 | + # local multi-index grids (can be 1D, 2D, 3D… arrays) |
| 62 | + inds_axes = [np.asarray(ax, dtype=np.intp) for ax in self.inds_k[k]] |
| 63 | + grids = np.meshgrid(*inds_axes, indexing="ij") |
| 64 | + |
| 65 | + # linearize to 1D! |
| 66 | + lin_idx = np.ravel_multi_index(tuple(grids), dims=grid_shape, order="C") |
| 67 | + lin_idx = lin_idx.ravel(order="C") |
| 68 | + |
| 69 | + # corresponding weak RHS weights, flattened to 1D |
| 70 | + wk = np.asarray(self.fulltweights[k], dtype=float).ravel(order="C") |
| 71 | + |
| 72 | + # ensure same length (paranoia check) |
| 73 | + if wk.shape[0] != lin_idx.shape[0]: |
| 74 | + raise RuntimeError( |
| 75 | + f"Weight/variance size mismatch on cell {k}: " |
| 76 | + f"wk has {wk.shape[0]} entries, indices have {lin_idx.shape[0]}" |
| 77 | + ) |
| 78 | + |
| 79 | + vals = wk * np.sqrt(var_flat[lin_idx]) |
| 80 | + |
| 81 | + idx_lists.append(lin_idx) |
| 82 | + val_lists.append(vals) |
| 83 | + |
| 84 | + # Build Cov[V] = B B^T with B_{k,i} = w_k[i] * sqrt(var[i]) |
| 85 | + Cov = np.zeros((K, K), dtype=float) |
| 86 | + for k in range(K): |
| 87 | + vk = val_lists[k] |
| 88 | + Cov[k, k] = np.dot(vk, vk) |
| 89 | + # off-diagonals via set intersection of supports |
| 90 | + idx_k = idx_lists[k] |
| 91 | + # Use a dict for fast overlap accumulation |
| 92 | + map_k = dict(zip(idx_k.tolist(), vk.tolist())) |
| 93 | + for ell in range(k + 1, K): |
| 94 | + s = 0.0 |
| 95 | + idx_e = idx_lists[ell] |
| 96 | + v_e = val_lists[ell] |
| 97 | + map_e = dict(zip(idx_e.tolist(), v_e.tolist())) |
| 98 | + # iterate the smaller map |
| 99 | + if len(map_k) <= len(map_e): |
| 100 | + for j, vkj in map_k.items(): |
| 101 | + ve = map_e.get(j) |
| 102 | + if ve is not None: |
| 103 | + s += vkj * ve |
| 104 | + else: |
| 105 | + for j, ve in map_e.items(): |
| 106 | + vk_j = map_k.get(j) |
| 107 | + if vk_j is not None: |
| 108 | + s += vk_j * ve |
| 109 | + Cov[k, ell] = s |
| 110 | + Cov[ell, k] = s |
| 111 | + |
| 112 | + # diagonal nugget for stability |
| 113 | + avg_diag = np.trace(Cov) / max(K, 1) |
| 114 | + nugget = 1e-12 * avg_diag |
| 115 | + Cov.flat[:: K + 1] += nugget |
| 116 | + |
| 117 | + # robust Cholesky with fallback if needed |
| 118 | + try: |
| 119 | + self._L_chol = np.linalg.cholesky(Cov) |
| 120 | + except np.linalg.LinAlgError: |
| 121 | + # inflate nugget and retry once |
| 122 | + Cov.flat[:: K + 1] += max(1e-10, 1e-6 * avg_diag) |
| 123 | + self._L_chol = np.linalg.cholesky(Cov) |
| 124 | + |
| 125 | + def _apply_whitener(self, A): |
| 126 | + """Return L^{-1} A without forming L^{-1} explicitly.""" |
| 127 | + if self._L_chol is None: |
| 128 | + return A |
| 129 | + # solve L X = A → X = L^{-1} A |
| 130 | + return np.linalg.solve(self._L_chol, A) |
| 131 | + |
| 132 | + # ------------------------------ hooks ------------------------------ |
| 133 | + |
| 134 | + def _weak_form_setup(self): |
| 135 | + # parent builds inds_k and the weak weight tensors |
| 136 | + super()._weak_form_setup() |
| 137 | + # then build the GLS whitener from the variance field |
| 138 | + if self.spatiotemporal_weights is not None: |
| 139 | + self._build_whitener_from_variance() |
| 140 | + |
| 141 | + def convert_u_dot_integral(self, u): |
| 142 | + Vy = super().convert_u_dot_integral(u) # (K, 1) |
| 143 | + Vy_w = self._apply_whitener(np.asarray(Vy)) |
| 144 | + return AxesArray(Vy_w, {"ax_sample": 0, "ax_coord": 1}) |
| 145 | + |
| 146 | + def transform(self, x_full): |
| 147 | + VTheta_list = super().transform(x_full) # list of (K, n_features) |
| 148 | + if self._L_chol is None: |
| 149 | + return VTheta_list |
| 150 | + out = [] |
| 151 | + for VTheta in VTheta_list: |
| 152 | + A = np.asarray(VTheta) |
| 153 | + A_w = self._apply_whitener(A) # (K, m) |
| 154 | + out.append(AxesArray(A_w, {"ax_sample": 0, "ax_coord": 1})) |
| 155 | + return out |
| 156 | + |
0 commit comments