@@ -43,14 +43,23 @@ def _mean_correction(samples, mu):
4343 return samples
4444
4545
46- def _fast_cholesky_covariance_correction ( samples , V , D ):
46+ def _transform_with_fast_cholesky_covariance_correction ( grid , cov ):
4747 # see [JAIF23_Frisch] V.E
4848
49+
50+ ew , V = np .linalg .eig (cov )
51+ D = np .diag (np .sqrt (ew ))
52+
53+ eps = 1e-9
54+ grid = np .clip (grid , eps , 1 - eps ) # avoid inf in ppf
55+ x_std = norm .ppf (grid )
56+ x_std = x_std - np .mean (x_std , axis = 0 , keepdims = True )
57+
4958 # variance correction
50- L = samples .shape [0 ]
51- v_d = 1 / L * np .sum (samples ** 2 , axis = 0 ) # shape (dim,)
59+ L = x_std .shape [0 ]
60+ v_d = 1 / L * np .sum (x_std ** 2 , axis = 0 ) # shape (dim,)
5261
53- X_stdD = samples / np .sqrt (v_d )
62+ X_stdD = x_std / np .sqrt (v_d )
5463
5564 # Fast Cholesky Covariance Correction
5665 C_stdD = 1 / L * (X_stdD .T @ X_stdD )
@@ -59,7 +68,7 @@ def _fast_cholesky_covariance_correction(samples, V, D):
5968 L_stdD_inv = np .linalg .inv (L_stdD )
6069 except np .linalg .LinAlgError :
6170 # In case L_stdD non PD (for example if C is almost 0, or other numerical issues), skip the correction
62- return samples
71+ return None
6372
6473 X_Gauss = V @ D @ L_stdD_inv @ X_stdD .T # (dim,dim) @ (dim,dim) @ (dim,dim) @ (dim,L) -> (dim,L)
6574 X_Gauss = X_Gauss .T # (L,dim)
@@ -102,11 +111,16 @@ def sample_gaussian_fibonacci(mu: list | np.ndarray, cov: np.ndarray, sample_cou
102111 dim = mu .shape [0 ]
103112 grid = get_uniform_grid (dim , sample_count , type )
104113
105- samples , V , D = _transform_grid_gaussian ( grid , mu , cov )
114+
106115
107116 # center for fast cholesky correction
108117 if sample_count > 1 :
109- samples = samples - np .mean (samples , axis = 0 )
110- samples = _fast_cholesky_covariance_correction (samples , V , D )
118+ samples = _transform_with_fast_cholesky_covariance_correction (grid , cov )
119+ if samples is None :
120+ # fallback to eigen decomposition method
121+ samples , V , D = _transform_grid_gaussian (grid , mu , cov )
122+ return samples
111123 samples = _mean_correction (samples , mu )
124+ else :
125+ samples , V , D = _transform_grid_gaussian (grid , mu , cov )
112126 return samples
0 commit comments