Commit 1d18201

add comparison with hmc
1 parent b90046c commit 1d18201

13 files changed (+624 -85 lines changed)

R/RcppExports.R

Lines changed: 0 additions & 19 deletions
@@ -9,25 +9,6 @@ run_bgm_parallel <- function(observations, num_categories, pairwise_scale, edge_
     .Call(`_bgms_run_bgm_parallel`, observations, num_categories, pairwise_scale, edge_prior, inclusion_probability, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, interaction_index_matrix, iter, warmup, counts_per_category, blume_capel_stats, main_alpha, main_beta, na_impute, missing_index, is_ordinal_variable, baseline_category, edge_selection, update_method, pairwise_effect_indices, target_accept, pairwise_stats, hmc_num_leapfrogs, nuts_max_depth, learn_mass_matrix, num_chains, nThreads, seed, progress_type)
 }
 
-<<<<<<< HEAD
-=======
-chol_update_arma <- function(R, u, downdate = FALSE, eps = 1e-12) {
-    .Call(`_bgms_chol_update_arma`, R, u, downdate, eps)
-}
-
-get_explog_switch <- function() {
-    .Call(`_bgms_get_explog_switch`)
-}
-
-rcpp_ieee754_exp <- function(x) {
-    .Call(`_bgms_rcpp_ieee754_exp`, x)
-}
-
-rcpp_ieee754_log <- function(x) {
-    .Call(`_bgms_rcpp_ieee754_log`, x)
-}
-
->>>>>>> 7252076 (ggm compiles but runtime has a weird error)
 sample_omrf_gibbs <- function(no_states, no_variables, no_categories, interactions, thresholds, iter) {
     .Call(`_bgms_sample_omrf_gibbs`, no_states, no_variables, no_categories, interactions, thresholds, iter)
 }

dev/ggm-hmc/README.md

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# GGM Constrained HMC Sampling
+
+This folder contains scripts for sampling precision matrices for Gaussian Graphical Models (GGMs) using constrained Hamiltonian Monte Carlo (HMC). The constrained HMC approach enforces exact zeros in the precision matrix (corresponding to missing edges in the graph) as hard constraints.
+
+## Overview
+
+- `run_constrained_hmc.py` - Python script that performs the constrained HMC sampling using JAX and [mici](https://github.com/matt-graham/mici)
+- `run_constrained_hmc_subprocess.R` - R script that generates data, calls the Python sampler via subprocess, and processes results
+
+## Prerequisites
+
+### 1. Clone the mici example repository
+
+The Python environment is defined in a separate repository. Clone it as a sibling folder:
+
+```bash
+cd /path/to/bgms/dev
+git clone https://github.com/matt-graham/ggm-precision-constrained-hmc
+```
+
+Your folder structure should look like:
+```
+bgms/dev/
+├── ggm-hmc/                          # This folder
+│   ├── README.md
+│   ├── run_constrained_hmc.py
+│   └── run_constrained_hmc_subprocess.R
+└── ggm-precision-constrained-hmc/    # Cloned repo with uv environment
+    ├── .venv/
+    ├── pyproject.toml
+    └── ...
+```
+
+**Alternative:** If you prefer a different location, adjust the `python_dir` path in `run_constrained_hmc_subprocess.R`.
+
+### 2. Set up the Python environment with uv
+
+Install [uv](https://docs.astral.sh/uv/) if you don't have it:
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh
+```
+
+Then create and sync the Python environment:
+
+```bash
+cd /path/to/bgms/dev/ggm-precision-constrained-hmc
+uv sync
+```
+
+This will create a `.venv` folder with Python 3.14 (free-threaded) and all required dependencies.
+
+## R Package Dependencies
+
+The R script requires:
+- `bgms` (this package)
+- `mvtnorm`
+- `jsonlite`
+- `RcppCNPy` (for reading numpy arrays)
+- `BDgraph` (for generating G-Wishart samples)
+
+```r
+install.packages(c("mvtnorm", "jsonlite", "RcppCNPy", "BDgraph"))
+```
+
+## Usage
+
+```r
+# Set working directory to bgms root
+setwd("/path/to/bgms")
+
+# Run the script
+source("dev/ggm-hmc/run_constrained_hmc_subprocess.R")
+```
+
+The script will:
+1. Generate simulated data from a sparse GGM
+2. Run the bgms MH sampler for comparison
+3. Call the Python constrained HMC sampler
+4. Load and summarize the results
+
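
For reference, the R script hands the sampler its inputs as a JSON file, and the expected keys can be read off `main()` in `run_constrained_hmc.py`. The following is a minimal sketch of assembling such a file by hand. All values and file names are hypothetical, and the 0-based indexing of `zero_indices` is an assumption based on how the Python script indexes its arrays.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical 3-variable problem; every value below is illustrative only.
work_dir = Path(tempfile.mkdtemp())
payload = {
    "n_variable": 3,
    "n_obs": 100,
    # Scatter matrix S = X'X as a nested list (symmetric positive definite).
    "scatter_matrix": [
        [110.0, 12.0, 0.5],
        [12.0, 95.0, 8.0],
        [0.5, 8.0, 102.0],
    ],
    # Assumed 0-based (row, column) pairs whose precision entries are fixed
    # to zero; here the single missing edge between variables 0 and 2.
    "zero_indices": [[0, 2]],
    "n_warm_up_iter": 500,
    "n_main_iter": 1000,
    "n_chain": 4,
    "seed": 123,
    # Hypothetical output locations for the summary and the samples.
    "output_file": str(work_dir / "hmc_summary.json"),
    "samples_file": str(work_dir / "hmc_samples.npy"),
}

input_file = work_dir / "hmc_input.json"
input_file.write_text(json.dumps(payload, indent=2))

# Read the file back to confirm it round-trips as plain JSON.
loaded = json.loads(input_file.read_text())
print(sorted(loaded))
```

With a file like this in place, the sampler would be started as `python run_constrained_hmc.py <input_json_file>`, matching the usage line in the script's docstring.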

dev/ggm-hmc/run_constrained_hmc.py

Lines changed: 188 additions & 0 deletions
@@ -0,0 +1,188 @@
+"""
+Constrained HMC sampler for GGM precision matrices.
+
+This script is called from R via subprocess to sample precision matrices
+with specified zero constraints using the mici library.
+
+Usage: python run_constrained_hmc.py <input_json_file>
+"""
+
+import json
+import sys
+import time
+import os
+import numpy as np
+import mici
+import jax
+from numpyro.distributions.transforms import LowerCholeskyTransform
+import arviz  # For diagnostics
+
+# Force CPU backend to avoid CUDA autotuner issues with small matrices
+# Comment out this line to use GPU (may require CUDA driver updates)
+jax.config.update("jax_default_device", jax.devices("cpu")[0])
+
+jax.config.update("jax_enable_x64", True)
+
+
+def main():
+    # Load input data from R
+    with open(sys.argv[1], "r") as f:
+        data = json.load(f)
+
+    n_variable = data["n_variable"]
+    n_obs = data["n_obs"]
+    S = np.array(data["scatter_matrix"], dtype=np.float64)
+    zero_indices = np.array(data["zero_indices"], dtype=np.int64)
+    n_warm_up_iter = data["n_warm_up_iter"]
+    n_main_iter = data["n_main_iter"]
+    n_chain = data["n_chain"]
+    seed = data["seed"]
+    output_file = data["output_file"]
+    samples_file = data["samples_file"]
+
+    print(f"Loaded data: n_obs={n_obs}, n_variable={n_variable}, n_zero_pairs={len(zero_indices)}")
+    print(f"Chains: {n_chain}, Warmup: {n_warm_up_iter}, Samples: {n_main_iter}")
+    sys.stdout.flush()
+
+    # Precompute Cholesky factor of scatter matrix for efficient trace computation
+    # tr(Omega S) = tr(L L^T S) = ||S_chol^T @ L||_F^2
+    S_chol = np.linalg.cholesky(S)
+
+    # Set up transformations
+    vector_to_cholesky = LowerCholeskyTransform()
+    cholesky_to_vector = vector_to_cholesky.inv
+
+    def constr(u, zero_indices):
+        L = vector_to_cholesky(u)
+        return jax.vmap(lambda i, j: L[i] @ L[j])(*zero_indices.T)
+
+    def neg_log_dens(u, n_obs, S_chol):
+        """
+        Negative log posterior for GGM precision matrix.
+
+        log p(Omega | X) = (n/2) * log|Omega| - (1/2) * tr(Omega * S) + log p(Omega) + const
+
+        where S = X'X is the scatter matrix.
+        We use a standard normal prior on the unconstrained parameters u.
+
+        Optimizations:
+        - Avoid forming Omega = L @ L.T explicitly
+        - Use tr(L L^T S) = ||S_chol^T @ L||_F^2 where S = S_chol @ S_chol^T
+        - This reduces from 2 O(p³) ops to 1 O(p³) op
+        """
+        L = vector_to_cholesky(u)
+
+        # Log determinant: log|Omega| = 2 * sum(log(diag(L)))
+        log_det_Omega = 2 * jax.numpy.log(jax.numpy.diag(L)).sum()
+
+        # Trace term: tr(Omega * S) = tr(L L^T S) = ||S_chol^T @ L||_F^2
+        # This avoids explicitly forming Omega = L @ L.T
+        S_chol_T_L = S_chol.T @ L
+        trace_term = jax.numpy.sum(S_chol_T_L ** 2)
+
+        # Gaussian likelihood: (n/2) * log|Omega| - (1/2) * tr(Omega * S)
+        log_likelihood = (n_obs / 2) * log_det_Omega - 0.5 * trace_term
+
+        # Prior on unconstrained parameters (standard normal)
+        log_prior = -0.5 * (u**2).sum()
+
+        # Return negative log posterior
+        return -(log_likelihood + log_prior)
+
+    def trace_func(state):
+        L = vector_to_cholesky(state.pos)
+        return {"u": state.pos, "P": L @ L.T}
+
+    # Create initial states
+    rng = np.random.default_rng(seed)
+    scale = 0.01
+
+    init_states = []
+    for c in range(n_chain):
+        random_matrix = rng.standard_normal((n_variable, n_variable))
+        P_init = np.identity(n_variable) + scale * random_matrix @ random_matrix.T
+        P_init[zero_indices[:, 0], zero_indices[:, 1]] = 0.0
+        P_init[zero_indices[:, 1], zero_indices[:, 0]] = 0.0
+        L_init = np.linalg.cholesky(P_init)
+        u_init = np.asarray(cholesky_to_vector(L_init))
+        assert not np.any(np.isnan(u_init)), "NaN in initial state"
+        assert abs(constr(u_init, zero_indices)).max() < 1e-8, "Constraint violation"
+        init_states.append(u_init)
+
+    print(f"Created {len(init_states)} initial states")
+    print("Running constrained HMC sampling...")
+    sys.stdout.flush()
+
+    # Time the sampling
+    start_time = time.perf_counter()
+
+    # Run sampling
+    results = mici.sample_constrained_hmc_chains(
+        n_warm_up_iter=n_warm_up_iter,
+        n_main_iter=n_main_iter,
+        init_states=init_states,
+        neg_log_dens=lambda u: neg_log_dens(u, n_obs, S_chol),
+        constr=lambda u: constr(u, zero_indices),
+        backend="jax",
+        seed=rng,
+        monitor_stats=("accept_stat", "n_step", "step_size"),
+        trace_funcs=[trace_func],
+        n_worker=1,
+        use_thread_pool=False,
+    )
+
+    end_time = time.perf_counter()
+    sampling_duration = end_time - start_time
+
+    print(f"\nSampling complete! Duration: {sampling_duration:.2f} seconds")
+    sys.stdout.flush()
+
+
+    ess = arviz.ess(results.traces, var_names=["u"])
+    r_hat = arviz.rhat(results.traces, var_names=["u"])
+
+    min_ess = float(ess.min().u.data)
+    max_rhat = float(r_hat.max().u.data)
+
+    # Extract posterior samples of precision matrix
+    # Shape: (n_chain, n_main_iter, n_variable, n_variable)
+    P_samples = np.array(results.traces["P"])
+    P_mean = P_samples.mean(axis=(0, 1))
+
+    print(f"Min ESS: {min_ess:.1f}, Max R-hat: {max_rhat:.3f}")
+    print(f"P_samples shape: {P_samples.shape}")
+    sys.stdout.flush()
+
+    # Save full posterior samples as numpy array
+    # Reshape to 2D for RcppCNPy compatibility (doesn't support 4D arrays)
+    # Original shape: (n_chain, n_main_iter, n_variable, n_variable)
+    # Saved shape: (n_chain * n_main_iter, n_variable * n_variable)
+    P_samples_flat = P_samples.reshape(-1, n_variable * n_variable)
+    np.save(samples_file, P_samples_flat)
+    print(f"Posterior samples saved to: {samples_file} (flattened shape: {P_samples_flat.shape})")
+
+    # Helper to convert NaN/Inf to None for JSON compatibility
+    def sanitize_for_json(val):
+        if isinstance(val, float) and (np.isnan(val) or np.isinf(val)):
+            return None
+        return val
+
+    # Save summary results as JSON
+    output_data = {
+        "min_ess": sanitize_for_json(min_ess),
+        "max_rhat": sanitize_for_json(max_rhat),
+        "P_mean": P_mean.tolist(),
+        "n_chain": n_chain,
+        "n_iter": n_main_iter,
+        "samples_file": samples_file,
+        "sampling_duration_seconds": sampling_duration,
+    }
+
+    with open(output_file, "w") as f:
+        json.dump(output_data, f)
+
+    print(f"Results saved to: {output_file}")
+
+
+if __name__ == "__main__":
+    main()
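
The trace and log-determinant shortcuts described in the `neg_log_dens` docstring can be checked numerically. The sketch below uses plain NumPy rather than JAX, with a hypothetical random scatter matrix and a hypothetical lower-triangular factor (neither comes from the commit). It compares the direct computation with the factored one.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 4

# Hypothetical scatter matrix S = X'X built from simulated data.
X = rng.standard_normal((50, p))
S = X.T @ X
S_chol = np.linalg.cholesky(S)  # lower triangular, S = S_chol @ S_chol.T

# Hypothetical lower-triangular factor L with a positive diagonal,
# so that Omega = L @ L.T is a valid precision matrix.
L = np.tril(rng.standard_normal((p, p)))
L[np.diag_indices(p)] = np.abs(np.diag(L)) + 0.5

# Direct route: form Omega, then take the trace of Omega @ S.
Omega = L @ L.T
trace_direct = np.trace(Omega @ S)

# Factored route: squared Frobenius norm of S_chol.T @ L.
trace_factored = np.sum((S_chol.T @ L) ** 2)

# Log-determinant from the diagonal of L, compared with slogdet on Omega.
log_det_from_L = 2.0 * np.log(np.diag(L)).sum()
sign, log_det_direct = np.linalg.slogdet(Omega)

print(trace_direct, trace_factored)
print(log_det_from_L, log_det_direct)
```

The two routes agree up to floating-point error because tr(L Lᵀ S) = tr(S_cholᵀ L Lᵀ S_chol), which is the sum of squared entries of S_cholᵀ L.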
