Skip to content

Commit 23b2804

Browse files
committed
fix transform with covariance correction
1 parent 240d136 commit 23b2804

File tree

2 files changed

+26
-11
lines changed

2 files changed

+26
-11
lines changed

src/deterministic_gaussian_sampling_fibonacci/sample_gaus.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,23 @@ def _mean_correction(samples, mu):
4343
return samples
4444

4545

46-
def _fast_cholesky_covariance_correction(samples, V, D):
46+
def _transform_with_fast_cholesky_covariance_correction(grid, cov):
4747
# see [JAIF23_Frisch] V.E
4848

49+
50+
ew, V = np.linalg.eig(cov)
51+
D = np.diag(np.sqrt(ew))
52+
53+
eps = 1e-9
54+
grid = np.clip(grid, eps, 1 - eps) # avoid inf in ppf
55+
x_std = norm.ppf(grid)
56+
x_std = x_std - np.mean(x_std, axis=0, keepdims=True)
57+
4958
# variance correction
50-
L = samples.shape[0]
51-
v_d = 1 / L * np.sum(samples**2, axis=0) # shape (dim,)
59+
L = x_std.shape[0]
60+
v_d = 1 / L * np.sum(x_std**2, axis=0) # shape (dim,)
5261

53-
X_stdD = samples / np.sqrt(v_d)
62+
X_stdD = x_std / np.sqrt(v_d)
5463

5564
# Fast Cholesky Covariance Correction
5665
C_stdD = 1 / L * (X_stdD.T @ X_stdD)
@@ -59,7 +68,7 @@ def _fast_cholesky_covariance_correction(samples, V, D):
5968
L_stdD_inv = np.linalg.inv(L_stdD)
6069
except np.linalg.LinAlgError:
6170
# In case L_stdD non PD (for example if C is almost 0, or other numerical issues), skip the correction
62-
return samples
71+
return None
6372

6473
X_Gauss = V @ D @ L_stdD_inv @ X_stdD.T # (dim,dim) @ (dim,dim) @ (dim,dim) @ (dim,L) -> (dim,L)
6574
X_Gauss = X_Gauss.T # (L,dim)
@@ -102,11 +111,16 @@ def sample_gaussian_fibonacci(mu: list | np.ndarray, cov: np.ndarray, sample_cou
102111
dim = mu.shape[0]
103112
grid = get_uniform_grid(dim, sample_count, type)
104113

105-
samples, V, D = _transform_grid_gaussian(grid, mu, cov)
114+
106115

107116
# center for fast cholesky correction
108117
if sample_count > 1:
109-
samples = samples - np.mean(samples, axis=0)
110-
samples = _fast_cholesky_covariance_correction(samples, V, D)
118+
samples = _transform_with_fast_cholesky_covariance_correction(grid, cov)
119+
if samples is None:
120+
# fallback to eigen decomposition method
121+
samples, V, D = _transform_grid_gaussian(grid, mu, cov)
122+
return samples
111123
samples = _mean_correction(samples, mu)
124+
else:
125+
samples, V, D = _transform_grid_gaussian(grid, mu, cov)
112126
return samples

tests/sample_gaus_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def test_001_gaus_sample_test(type, dim):
2424
mu_pred = np.mean(samp, axis=0) # shape (3,)
2525
C_pred = np.cov(samp, rowvar=False, bias=True)
2626

27-
assert np.all(np.isclose(mu, mu_pred, 10**-2))
28-
if sampcount > 10000:
29-
assert np.all(np.isclose(cov, C_pred, 10**-2))
27+
assert np.all(np.isclose(mu, mu_pred, 10**-15))
28+
assert np.all(np.isclose(cov, C_pred, 10**-15))
29+
30+
3031

3132
@pytest.mark.parametrize("type", FIB_TYPES)
3233
@pytest.mark.parametrize("dim", FIB_DIMS)

0 commit comments

Comments
 (0)