35 commits
38e4efe
add auto-marginalization implementation to the main package
sunxd3 Aug 23, 2025
a5aeb6b
format
sunxd3 Aug 23, 2025
b4eb8d5
add import
sunxd3 Aug 23, 2025
fbfac76
import Exponential distribution
sunxd3 Aug 23, 2025
3c0504e
using LogExpFunctions in tests
sunxd3 Aug 23, 2025
6b1206e
test on GMM
sunxd3 Aug 23, 2025
7243c23
fix some errors
sunxd3 Aug 26, 2025
2bc9e4b
add hmm sanity check
sunxd3 Aug 28, 2025
ad90ec8
format some example files
sunxd3 Aug 28, 2025
6de0807
fix error
sunxd3 Aug 29, 2025
0445a28
add sampling tests
sunxd3 Aug 29, 2025
f06e2f1
formatting
sunxd3 Aug 29, 2025
184434c
Update JuliaBUGS/test/model/auto_marginalization_sampling.jl
sunxd3 Aug 29, 2025
bc107b2
Update JuliaBUGS/test/model/auto_marginalization.jl
sunxd3 Aug 29, 2025
92a907b
Update JuliaBUGS/test/model/auto_marginalization.jl
sunxd3 Aug 29, 2025
434afd1
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Aug 29, 2025
e812006
fix test error
sunxd3 Aug 29, 2025
618c23e
stop using `invokelatest`
sunxd3 Sep 1, 2025
7eb5224
fix test errors
sunxd3 Sep 1, 2025
f441205
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Sep 1, 2025
7b8cedd
fix performance by moving more computation to the construction
sunxd3 Sep 1, 2025
2d66d1c
Update bugsmodel.jl
sunxd3 Sep 2, 2025
5f00a14
add experiment package
sunxd3 Sep 14, 2025
4ea7726
example
sunxd3 Sep 23, 2025
c52e5e0
remove experiments
sunxd3 Sep 23, 2025
7a38695
chore: empty commit [skip ci]
sunxd3 Sep 23, 2025
1d229ca
chore: revert example formatting changes
sunxd3 Sep 23, 2025
f7a044f
move the auto-marg doc into JuliaBUGS docs folder
sunxd3 Sep 23, 2025
6d2bc02
rename the auto marg doc
sunxd3 Sep 23, 2025
f6002b3
Merge branch 'main' into sunxd/auto_marginalization
sunxd3 Sep 23, 2025
d8874e4
fix: stabilize auto-marginalization caches and tempering
sunxd3 Sep 23, 2025
cf180e3
add experiment code
sunxd3 Sep 24, 2025
f62ff91
check in all the experiment code
sunxd3 Sep 30, 2025
2d7a51b
update scripts; remove results
sunxd3 Sep 30, 2025
aec0159
update plan
sunxd3 Sep 30, 2025
412 changes: 412 additions & 0 deletions JuliaBUGS/docs/src/auto_marginalization.md

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions JuliaBUGS/experiments/Project.toml
@@ -0,0 +1,24 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
102 changes: 102 additions & 0 deletions JuliaBUGS/experiments/README.md
@@ -0,0 +1,102 @@
# Experiments Workspace

- Project path: `JuliaBUGS/experiments`
- Scripts live in `JuliaBUGS/experiments/scripts/`
- Shared helpers are in `JuliaBUGS/experiments/utils.jl`

## Running scripts

Always pass the experiments project to Julia so the correct environment loads:

```bash
julia --project=JuliaBUGS/experiments scripts/hmm_marginal_logp.jl
```

Most scripts accept environment variables to tweak configurations. The simple
HMM example supports:

- `AM_SEED` (default `1`) – RNG seed used for synthetic data.
- `AM_T` (default `50`) – length of the simulated sequence.

Batch sweep (`hmm_correctness_sweep.jl`):

- `AM_SWEEP_SEEDS` (default `1`) – comma-separated list of seeds.
- `AM_SWEEP_K` (default `2,4`) – comma-separated list of state counts.
- `AM_SWEEP_T` (default `50,200`) – comma-separated list of sequence lengths.

GMM sweep (`gmm_correctness_sweep.jl`):

- `AG_SWEEP_SEEDS` (default `1`) – comma-separated list of seeds.
- `AG_SWEEP_K` (default `2,4`) – comma-separated list of mixture counts.
- `AG_SWEEP_N` (default `100,1000`) – comma-separated list of observation counts.

HMM gradient check (`hmm_gradient_check.jl`):

- `AGC_SEED` (default `1`) – RNG seed for synthetic data.
- `AGC_K` (default `2`) – number of HMM states.
- `AGC_T` (default `50`) – length of the simulated sequence.
- `AGC_EPS` (default `1e-5`) – step size for central finite differences.
- `AGC_VERBOSE` (default `0`) – set `1` to print per-θ details.
- `AGC_SWEEP_SEEDS` – comma-separated list of seeds (overrides `AGC_SEED`).
- `AGC_SWEEP_K` – comma-separated list of state counts (overrides `AGC_K`).
- `AGC_SWEEP_T` – comma-separated list of sequence lengths (overrides `AGC_T`).
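
For reference, the central finite-difference estimate that `AGC_EPS` controls has the standard form below. The symbols are chosen here for illustration and are not taken from the script: $f$ stands for the marginalized log-density, $\theta$ for the continuous parameters, and $e_i$ for the $i$-th unit vector.

$$
\frac{\partial f}{\partial \theta_i}(\theta) \approx \frac{f(\theta + \varepsilon e_i) - f(\theta - \varepsilon e_i)}{2\varepsilon},
\qquad \text{truncation error } O(\varepsilon^2).
$$

A smaller $\varepsilon$ reduces the truncation error but amplifies floating-point round-off in the numerator, which is why a moderate step such as the default `1e-5` is a common compromise.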

GMM gradient check (`gmm_gradient_check.jl`):

- `AGG_SEED` (default `1`) – RNG seed for synthetic data.
- `AGG_K` (default `2`) – number of mixture components.
- `AGG_N` (default `200`) – number of observations.
- `AGG_EPS` (default `1e-5`) – step size for central finite differences.
- `AGG_VERBOSE` (default `0`) – set `1` to print per-θ details.
- `AGG_SWEEP_SEEDS` – comma-separated list of seeds (overrides `AGG_SEED`).
- `AGG_SWEEP_K` – comma-separated list of mixture counts (overrides `AGG_K`).
- `AGG_SWEEP_N` – comma-separated list of observation counts (overrides `AGG_N`).

HMM scaling benchmark (`hmm_scaling_bench.jl`):

- `AS_SEED` (default `1`) – RNG seed. Use `AS_SWEEP_SEEDS` for a list.
- `AS_K` (default `2,4`) – number of states. Use `AS_SWEEP_K` for a list.
- `AS_T` (default `50,200`) – sequence length. Use `AS_SWEEP_T` for a list.
- `AS_TRIALS` (default `5`) – number of timing repetitions per case.

Notes:
- The benchmark enforces the interleaved (time-first) order to reflect optimal scaling.

FHMM order comparison (`fhmm_order_comparison.jl`):

- `AFH_SEED` (default `1`) – RNG seed.
- `AFH_C` (default `2`) – number of chains.
- `AFH_K` (default `4`) – number of states per chain.
- `AFH_T` (default `100`) – length of the sequence.
- `AFH_TRIALS` (default `10`) – timing samples per order.
- `AFH_MODE` (default `frontier`) – `frontier` or `timed`. Interleaved is always timed; the bad order is timed only if its proxy cost ≤ `AFH_COST_THRESH` (or when `AFH_MODE=timed`).
- `AFH_COST_THRESH` (default `1e8`) – threshold on the proxy Σ K^width (compared in log-space) to avoid intractable timings.
- `AFH_ORDERS` (optional) – comma‑separated list of orders to run. Accepted values: `interleaved`, `states_then_y`. Default runs both in that order. Example: `AFH_ORDERS=interleaved` or `AFH_ORDERS=states_then_y`.
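
One way to read the proxy Σ K^width and its log-space comparison is sketched below. It assumes $w_t$ denotes the frontier width at elimination step $t$; this notation is introduced here and may not match the script's internals.

$$
\log \sum_t K^{w_t}
= \operatorname{logsumexp}_t \,(w_t \log K)
= m + \log \sum_t e^{\,w_t \log K - m},
\qquad m = \max_t \, w_t \log K .
$$

Under this reading, an order would be timed only when the left-hand side is at most $\log(\texttt{AFH\_COST\_THRESH})$. Working with $w_t \log K$ rather than $K^{w_t}$ keeps the comparison finite even when the widths grow with T and the raw sum would overflow.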

Output: CSV lines with columns
`order,max_frontier,mean_frontier,sum_frontier,log_cost_proxy,min_time_sec,logp`.
Two consistent orders are evaluated:
- `interleaved` (time‑first, tractable)
- `states_then_y` (all z’s, then all y’s; typically intractable for moderate T)

Output: CSV lines with columns
`seed,K,T,trials,min_time_sec,logp,max_frontier,mean_frontier,sum_frontier`.

HMT order comparison (`hmt_order_comparison.jl`):

- `AHMT_SEED` (default `1`) – RNG seed.
- `AHMT_B` (default `2`) – branching factor.
- `AHMT_DEPTH` (default `8`) – tree depth.
- `AHMT_K` (default `4`) – number of states per node.
- `AHMT_TRIALS` (default `10`) – timing samples when timing is enabled.
- `AHMT_MODE` (default `frontier`) – `frontier` (frontier/proxy only), `timed` (time all listed orders), or `dfs` (time DFS only). BFS is not timed when its cost proxy is too large to run in reasonable time.

Reproduce figures (PDFs in `experiments/figures`):

```bash
julia --project=JuliaBUGS/experiments experiments/plotting/make_figures.jl
```

Tables included in the draft live in `experiments/tables/` and are generated from CSV outputs under `experiments/results/`.
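
Since the scripts emit CSV lines, a small wrapper can route each run into the results folder. The following is a sketch only: it assumes the scripts print their CSV to standard output, and `run_to_csv` and `DRY_RUN` are hypothetical names that do not exist in the repository.

```bash
# Hypothetical helper (not in the repository): run one experiment script and
# save its stdout as a CSV file. Environment variables such as AM_SWEEP_K must
# be exported by the caller before calling this function.
# With DRY_RUN=1 the command is printed instead of executed.
run_to_csv() {
    script="$1"   # e.g. scripts/hmm_correctness_sweep.jl
    out="$2"      # e.g. experiments/results/hmm_correctness.csv
    if [ "${DRY_RUN:-0}" = "1" ]; then
        echo "julia --project=JuliaBUGS/experiments $script > $out"
    else
        mkdir -p "$(dirname "$out")"
        julia --project=JuliaBUGS/experiments "$script" > "$out"
    fi
}
```

For example, after `export AM_SWEEP_K=2,4`, calling `run_to_csv scripts/hmm_correctness_sweep.jl experiments/results/hmm_correctness.csv` would write one file for that sweep; the output file name is an example, not a convention the repository defines.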

When adding new scripts, document their environment variables near the top of the file and list them here for quick reference.
109 changes: 109 additions & 0 deletions JuliaBUGS/experiments/experiment_plan.md
@@ -0,0 +1,109 @@
# Experiment Plan: Auto-Marginalization

Experiments validating automatic marginalization of discrete latent variables in JuliaBUGS.

## 1. Correctness

Validates marginalized log-probability against analytical references.
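
As an illustration of what the marginalized log-probability is, the discrete latent variables are summed out of the joint density. The notation is assumed for this sketch ($z$ for the discrete latents, $\theta$ for the continuous parameters), and the second line further assumes a Gaussian mixture with independent observations, weights $\pi_k$, means $\mu_k$, and standard deviations $\sigma_k$.

$$
\log p(y \mid \theta) = \log \sum_{z} p(y, z \mid \theta)
$$

$$
\log p(y_{1:N} \mid \theta) = \sum_{n=1}^{N} \log \sum_{k=1}^{K} \pi_k \, \mathcal{N}\!\left(y_n \mid \mu_k, \sigma_k^2\right)
$$

A reference value computed this way can be compared against the value returned by the auto-marginalized model at the same $\theta$.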

```bash
# HMM
AM_SWEEP_SEEDS=1,2,3 AM_SWEEP_K=2,4,8,16 AM_SWEEP_T=50,100,200,400 \
julia --project=JuliaBUGS/experiments scripts/hmm_correctness_sweep.jl

# GMM
AG_SWEEP_SEEDS=1,2,3 AG_SWEEP_K=2,4,8 AG_SWEEP_N=100,500,1000,5000 \
julia --project=JuliaBUGS/experiments scripts/gmm_correctness_sweep.jl

# HDP-HMM (κ=0, standard non-sticky case)
AHDPC_SEEDS=1,2 AHDPC_K=5,10,20 AHDPC_T=50,100,200,400 AHDPC_KAPPA=0.0 \
julia --project=JuliaBUGS/experiments scripts/hdphmm_correctness.jl

# HDP-HMM (sticky, κ=5)
AHDPC_SEEDS=1,2 AHDPC_K=5,10,20 AHDPC_T=50,100,200,400 AHDPC_KAPPA=5.0 \
julia --project=JuliaBUGS/experiments scripts/hdphmm_correctness.jl
```

## 2. Gradients

Validates automatic differentiation against finite differences.

```bash
# HMM
AGC_SWEEP_SEEDS=1,2,3 AGC_SWEEP_K=2,4,8 AGC_SWEEP_T=50,100,200 \
julia --project=JuliaBUGS/experiments scripts/hmm_gradient_check.jl

# GMM
AGG_SWEEP_SEEDS=1,2,3 AGG_SWEEP_K=2,4,8 AGG_SWEEP_N=200,500,1000 \
julia --project=JuliaBUGS/experiments scripts/gmm_gradient_check.jl

# HDP-HMM (κ=0, standard non-sticky case)
AHDPG_SWEEP_SEEDS=1,2 AHDPG_SWEEP_K=5,10,20 AHDPG_SWEEP_T=100,200 AHDPG_KAPPA=0.0 \
julia --project=JuliaBUGS/experiments scripts/hdphmm_gradient_check.jl

# HDP-HMM (sticky, κ=5)
AHDPG_SWEEP_SEEDS=1,2 AHDPG_SWEEP_K=5,10,20 AHDPG_SWEEP_T=100,200 AHDPG_KAPPA=5.0 \
julia --project=JuliaBUGS/experiments scripts/hdphmm_gradient_check.jl
```

## 3. Scaling

Benchmarks runtime against problem size to verify O(T·K²) complexity once JIT and other fixed overhead no longer dominate.

```bash
# HMM - Push into asymptotic regime
AS_SWEEP_K=8,16,32,64,128,256,512,1024,2048 \
AS_SWEEP_T=50,100,200,400,800,1600,3200,6400 \
AS_TRIALS=10 \
julia --project=JuliaBUGS/experiments scripts/hmm_scaling_bench.jl
```
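
A simple way to read the resulting timings, assuming the runtime follows $t(T, K) \approx c\,T K^2$ once overhead is negligible (the constant $c$ is machine-dependent and introduced here only for illustration):

$$
\log t \approx \log c + \log T + 2 \log K .
$$

Under that assumption, doubling T at fixed K should roughly double the measured time, and doubling K at fixed T should roughly quadruple it. Clear deviations at small T or K would point to the overhead regime rather than to a different asymptotic cost.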

## 4. Variable Ordering: FHMM

Compares elimination orders (interleaved, states_then_y, min_fill, min_degree).

```bash
# Small configs with timing
AFH_C=2 AFH_K=2 AFH_T=5 AFH_MODE=timed AFH_ORDERS=interleaved,states_then_y \
julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl
AFH_C=2 AFH_K=4 AFH_T=10 AFH_MODE=timed AFH_ORDERS=interleaved,states_then_y \
julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl

# Larger configs (frontier only)
AFH_C=2 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \
julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl
AFH_C=3 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \
julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl
AFH_C=4 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \
julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl
```

## 5. Variable Ordering: HMT

Compares tree traversal orders (dfs, bfs, random_dfs, min_fill, min_degree).

```bash
# Varying depth
AHMT_B=2 AHMT_K=4 AHMT_DEPTH=4 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
AHMT_B=2 AHMT_K=4 AHMT_DEPTH=6 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
AHMT_B=2 AHMT_K=4 AHMT_DEPTH=8 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
AHMT_B=2 AHMT_K=4 AHMT_DEPTH=10 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl

# Varying branching and states
AHMT_B=2 AHMT_K=2 AHMT_DEPTH=6 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
AHMT_B=3 AHMT_K=2 AHMT_DEPTH=6 AHMT_MODE=frontier \
julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
```
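
The repeated "Varying depth" commands can also be written as a loop. This is a sketch under assumptions: `hmt_depth_sweep` is a hypothetical function name, and the optional first argument (the program to invoke, default `julia`) exists only so the loop can be previewed without running Julia.

```bash
# Hypothetical loop over the tree depths used in the plan (4, 6, 8, 10).
# The first argument is the program to run; it defaults to `julia`.
# Passing `echo` prints the arguments of each invocation instead of running it.
hmt_depth_sweep() {
    runner="${1:-julia}"
    for depth in 4 6 8 10; do
        AHMT_B=2 AHMT_K=4 AHMT_DEPTH="$depth" AHMT_MODE=frontier \
            "$runner" --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl
    done
}
```

Calling `hmt_depth_sweep` runs the four configurations in sequence, and `hmt_depth_sweep echo` previews them. The preview shows only the command-line arguments, not the environment variables set for each run.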

## Notes

- **Complexity**: With good (interleaved) ordering, HMM marginalization achieves **O(T·K²)** complexity (linear in T). Bad orderings (e.g., states-first) explode to O(K^T) by enumerating all state sequences. Frontier width measures active discrete variables: good orders keep it ≈ O(1), bad orders reach O(T).
- **Heuristics**: Min-fill and min-degree with randomized tie-breaking (3 restarts) find good orders for arbitrary graphical models.
- **HDP-HMM**: Both the correctness and gradient scripts use the sticky HDP-HMM formulation with a κ (kappa) parameter. Set `AHDPC_KAPPA` / `AHDPG_KAPPA` to control the self-transition bias: κ=0 recovers the standard HDP-HMM, and κ>0 adds a preference for self-transitions.
- **Scaling**: The sweep includes K up to 2048 states and T up to 6400 time steps so that the asymptotic O(T·K²) behavior is visible beyond the JIT/overhead regime. On a log-log plot of time against T at fixed K, expect slope 1 at large T. The lines for different K are parallel, each shifted vertically by ≈2·log(K), which shows that the cost stays polynomial even for very large state spaces.
- **Output**: All scripts write CSV to stdout. Redirect as needed: `> results/output.csv`
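
The O(T·K²) figure in the complexity note matches the standard HMM forward recursion, which a time-first elimination order corresponds to. The sketch below uses assumed notation: $\pi$ for the initial state distribution, $A$ for the $K \times K$ transition matrix, and $\alpha_t(j)$ for the joint probability of the first $t$ observations and state $j$ at time $t$.

$$
\alpha_1(j) = \pi_j \, p(y_1 \mid z_1 = j), \qquad
\alpha_t(j) = p(y_t \mid z_t = j) \sum_{i=1}^{K} \alpha_{t-1}(i)\, A_{ij}, \qquad
p(y_{1:T}) = \sum_{j=1}^{K} \alpha_T(j).
$$

Each step costs on the order of $K^2$ multiply-adds and there are $T$ steps, giving $O(T K^2)$. Direct enumeration instead sums over all $K^T$ state sequences: for $K = 4$ and $T = 50$ that is $4^{50} \approx 1.3 \times 10^{30}$ terms, against $50 \cdot 4^2 = 800$ inner products' worth of work for the recursion.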