|
| 1 | +""" |
| 2 | +Constrained HMC sampler for GGM precision matrices. |
| 3 | +
|
| 4 | +This script is called from R via subprocess to sample precision matrices |
| 5 | +with specified zero constraints using the mici library. |
| 6 | +
|
| 7 | +Usage: python run_constrained_hmc.py <input_json_file> |
| 8 | +""" |
| 9 | + |
| 10 | +import json |
| 11 | +import sys |
| 12 | +import time |
| 13 | +import os |
| 14 | +import numpy as np |
| 15 | +import mici |
| 16 | +import jax |
| 17 | +from numpyro.distributions.transforms import LowerCholeskyTransform |
| 18 | +import arviz # For diagnostics |
| 19 | + |
# Pin JAX to the CPU backend: the matrices handled here are small, and the
# CUDA autotuner has caused problems with them. Remove this call to run on
# GPU instead (may require CUDA driver updates).
jax.config.update("jax_default_device", jax.devices(backend="cpu")[0])

# Switch JAX to 64-bit (double-precision) types.
jax.config.update("jax_enable_x64", True)
| 25 | + |
| 26 | + |
def _sanitize_for_json(val):
    """Recursively replace NaN/Inf floats with None so the value is valid JSON."""
    if isinstance(val, (list, tuple)):
        return [_sanitize_for_json(item) for item in val]
    if isinstance(val, float) and not np.isfinite(val):
        return None
    return val


def main(argv=None):
    """
    Sample zero-constrained precision matrices and write the results to disk.

    Reads a JSON configuration file, runs mici's constrained HMC sampler on
    the Cholesky parameterisation of the precision matrix, then writes the
    flattened posterior samples (``.npy``) and a JSON summary.

    Parameters
    ----------
    argv : sequence of str, optional
        Command-line arguments in ``sys.argv`` layout (program name first,
        input JSON path second). Defaults to ``sys.argv``.

    Raises
    ------
    SystemExit
        If no input JSON path is supplied.
    RuntimeError
        If an initial state contains NaN or violates the zero constraints.
    numpy.linalg.LinAlgError
        If the scatter matrix (or an initial precision matrix) is not
        positive definite, so its Cholesky factorisation fails.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        # Exit with a readable message instead of a bare IndexError.
        sys.exit("Usage: python run_constrained_hmc.py <input_json_file>")

    # Load input data from R
    with open(argv[1], "r", encoding="utf-8") as f:
        data = json.load(f)

    n_variable = data["n_variable"]
    n_obs = data["n_obs"]
    S = np.array(data["scatter_matrix"], dtype=np.float64)
    # Normalise to shape (n_pairs, 2): an empty list or a single flat [i, j]
    # pair would otherwise decode to a 1-D array and break the column indexing
    # used below.
    zero_indices = np.array(data["zero_indices"], dtype=np.int64).reshape(-1, 2)
    n_warm_up_iter = data["n_warm_up_iter"]
    n_main_iter = data["n_main_iter"]
    n_chain = data["n_chain"]
    seed = data["seed"]
    output_file = data["output_file"]
    samples_file = data["samples_file"]

    print(f"Loaded data: n_obs={n_obs}, n_variable={n_variable}, n_zero_pairs={len(zero_indices)}")
    print(f"Chains: {n_chain}, Warmup: {n_warm_up_iter}, Samples: {n_main_iter}")
    sys.stdout.flush()

    # Precompute Cholesky factor of scatter matrix for efficient trace computation
    # tr(Omega S) = tr(L L^T S) = ||S_chol^T @ L||_F^2
    S_chol = np.linalg.cholesky(S)

    # Map between unconstrained vectors and lower-triangular Cholesky factors
    vector_to_cholesky = LowerCholeskyTransform()
    cholesky_to_vector = vector_to_cholesky.inv

    def constr(u, zero_indices):
        """Return Omega[i, j] = L[i] @ L[j] for every constrained (i, j) pair."""
        L = vector_to_cholesky(u)
        return jax.vmap(lambda i, j: L[i] @ L[j])(*zero_indices.T)

    def neg_log_dens(u, n_obs, S_chol):
        """
        Negative log posterior for GGM precision matrix.

        log p(Omega | X) ∝ (n/2) * log|Omega| - (1/2) * tr(Omega * S) + log p(Omega)

        where S = X'X is the scatter matrix.
        We use a standard normal prior on the unconstrained parameters u.

        Optimizations:
        - Avoid forming Omega = L @ L.T explicitly
        - Use tr(L L^T S) = ||S_chol^T @ L||_F^2 where S = S_chol @ S_chol^T
        - This reduces from 2 O(p³) ops to 1 O(p³) op
        """
        L = vector_to_cholesky(u)

        # Log determinant: log|Omega| = 2 * sum(log(diag(L)))
        log_det_Omega = 2 * jax.numpy.log(jax.numpy.diag(L)).sum()

        # Trace term: tr(Omega * S) = tr(L L^T S) = ||S_chol^T @ L||_F^2
        # This avoids explicitly forming Omega = L @ L.T
        S_chol_T_L = S_chol.T @ L
        trace_term = jax.numpy.sum(S_chol_T_L ** 2)

        # Gaussian likelihood: (n/2) * log|Omega| - (1/2) * tr(Omega * S)
        log_likelihood = (n_obs / 2) * log_det_Omega - 0.5 * trace_term

        # Prior on unconstrained parameters (standard normal)
        log_prior = -0.5 * (u**2).sum()

        # Return negative log posterior
        return -(log_likelihood + log_prior)

    def trace_func(state):
        """Record the unconstrained position and the implied precision matrix."""
        L = vector_to_cholesky(state.pos)
        return {"u": state.pos, "P": L @ L.T}

    # Create initial states
    rng = np.random.default_rng(seed)
    scale = 0.01
    zero_rows = zero_indices[:, 0]
    zero_cols = zero_indices[:, 1]

    init_states = []
    for _ in range(n_chain):
        random_matrix = rng.standard_normal((n_variable, n_variable))
        P_init = np.identity(n_variable) + scale * random_matrix @ random_matrix.T
        # Zero both triangles so the starting point satisfies the constraints.
        P_init[zero_rows, zero_cols] = 0.0
        P_init[zero_cols, zero_rows] = 0.0
        L_init = np.linalg.cholesky(P_init)
        u_init = np.asarray(cholesky_to_vector(L_init))
        # Explicit raises rather than `assert`, which is stripped under -O.
        if np.any(np.isnan(u_init)):
            raise RuntimeError("NaN in initial state")
        if zero_indices.size:
            # Largest absolute violation: abs() is applied before max() so
            # negative residuals are caught as well.
            violation = float(np.abs(np.asarray(constr(u_init, zero_indices))).max())
            if not violation < 1e-8:
                raise RuntimeError("Constraint violation")
        init_states.append(u_init)

    print(f"Created {len(init_states)} initial states")
    print("Running constrained HMC sampling...")
    sys.stdout.flush()

    # Time the sampling
    start_time = time.perf_counter()

    # Run sampling
    results = mici.sample_constrained_hmc_chains(
        n_warm_up_iter=n_warm_up_iter,
        n_main_iter=n_main_iter,
        init_states=init_states,
        neg_log_dens=lambda u: neg_log_dens(u, n_obs, S_chol),
        constr=lambda u: constr(u, zero_indices),
        backend="jax",
        seed=rng,
        monitor_stats=("accept_stat", "n_step", "step_size"),
        trace_funcs=[trace_func],
        n_worker=1,
        use_thread_pool=False,
    )

    end_time = time.perf_counter()
    sampling_duration = end_time - start_time

    print(f"\nSampling complete! Duration: {sampling_duration:.2f} seconds")
    sys.stdout.flush()

    # Convergence diagnostics on the unconstrained parameters
    ess = arviz.ess(results.traces, var_names=["u"])
    r_hat = arviz.rhat(results.traces, var_names=["u"])

    min_ess = float(ess.min().u.data)
    max_rhat = float(r_hat.max().u.data)

    # Extract posterior samples of precision matrix
    # Shape: (n_chain, n_main_iter, n_variable, n_variable)
    P_samples = np.array(results.traces["P"])
    P_mean = P_samples.mean(axis=(0, 1))

    print(f"Min ESS: {min_ess:.1f}, Max R-hat: {max_rhat:.3f}")
    print(f"P_samples shape: {P_samples.shape}")
    sys.stdout.flush()

    # Save full posterior samples as numpy array
    # Reshape to 2D for RcppCNPy compatibility (doesn't support 4D arrays)
    # Original shape: (n_chain, n_main_iter, n_variable, n_variable)
    # Saved shape: (n_chain * n_main_iter, n_variable * n_variable)
    P_samples_flat = P_samples.reshape(-1, n_variable * n_variable)
    np.save(samples_file, P_samples_flat)
    print(f"Posterior samples saved to: {samples_file} (flattened shape: {P_samples_flat.shape})")

    # Save summary results as JSON. Every float is sanitized (including the
    # entries of P_mean) so a diverged run cannot emit bare NaN/Infinity
    # tokens, which are not valid JSON.
    output_data = {
        "min_ess": _sanitize_for_json(min_ess),
        "max_rhat": _sanitize_for_json(max_rhat),
        "P_mean": _sanitize_for_json(P_mean.tolist()),
        "n_chain": n_chain,
        "n_iter": n_main_iter,
        "samples_file": samples_file,
        "sampling_duration_seconds": _sanitize_for_json(sampling_duration),
    }

    with open(output_file, "w", encoding="utf-8") as f:
        # allow_nan=False makes any remaining non-finite value fail loudly.
        json.dump(output_data, f, allow_nan=False)

    print(f"Results saved to: {output_file}")
| 185 | + |
| 186 | + |
# Entry point: run the sampler only when this file is executed as a script
# (e.g. launched from R as a subprocess), not when it is imported.
if __name__ == "__main__":
    main()
0 commit comments