diff --git a/JuliaBUGS/docs/src/auto_marginalization.md b/JuliaBUGS/docs/src/auto_marginalization.md new file mode 100644 index 000000000..8d6c33cb4 --- /dev/null +++ b/JuliaBUGS/docs/src/auto_marginalization.md @@ -0,0 +1,412 @@ +# Automatic Marginalization in Mixed Discrete-Continuous Models + +## The Problem + +Modern gradient-based inference methods require differentiable log-densities: + +- HMC, NUTS, variational inference need smooth objectives +- Discrete variables aren't differentiable +- Solution: marginalize out discrete variables exactly + +**Challenge**: How to do this efficiently? + + +## Example Model + +Mixed BN: discrete $X,C,Z$; continuous $A,B$; observed $D$. + +```mermaid +graph LR + X((X)):::discrete --> A((A)):::continuous + A --> B((B)):::continuous + A --> C((C)):::discrete + B --> D((D)):::observed + C --> D + Z((Z)):::discrete --> D + + classDef discrete fill:#FFF4E6,stroke:#D9480F,stroke-width:2px,stroke-dasharray:5 + classDef continuous fill:#E7F5FF,stroke:#1C7ED6,stroke-width:2px + classDef observed fill:#E6FCF5,stroke:#2B8A3E,stroke-width:2px +``` + + +## BUGS Program + +```r +model { + # Discrete priors + X ~ dcat(piX[]) # piX is length-|X| + Z ~ dcat(piZ[]) # piZ is length-|Z| + + # Conditionals + A ~ dnorm(muX[X], 1/pow(sigmaA,2)) + B ~ dnorm(A, 1/pow(sigmaB,2)) + + # Logistic gate for C + logit(pC) <- alpha0 + alpha1 * A + pCvec[1] <- 1 - pC + pCvec[2] <- pC + C ~ dcat(pCvec[]) + + # Likelihood + D ~ dnorm(B + deltaC[C] + deltaZ[Z], 1/pow(sigmaD,2)) +} +``` + + +## Joint Probability + +Following topological order $z, x, a, b, c, d$: + +$$ +\begin{aligned} +\log p(z, x, a, b, c, d) &= \log p(z) + \log p(x) + \log p(a \mid x) \\ +&\quad + \log p(b \mid a) + \log p(c \mid a) \\ +&\quad + \log p(d \mid b, c, z) +\end{aligned} +$$ + + +## Marginalization Target + +For gradient-based inference, marginalize out discrete variables: + +$$ +\begin{aligned} +\log p(a, b, d) = \log &\sum_{z \in \mathcal{Z}} \sum_{x \in \mathcal{X}} \Big[ p(z) \cdot p(x) \cdot p(a \mid x) \cdot p(b \mid a) \\ +&\cdot \sum_{c \in \mathcal{C}} p(c \mid a) \cdot p(d \mid b, c, z) \Big] +\end{aligned} +$$ + +following the same topological order in the computation. + + +## Naive Enumeration + +Full expansion reveals massive redundancy: + +$$ +\begin{aligned} +&p(z{=}1) \cdot p(x{=}1) \cdot p(a|x{=}1) \cdot p(b|a) \cdot p(c{=}1|a) \cdot p(d|b,c{=}1,z{=}1) \\ +&p(z{=}1) \cdot p(x{=}1) \cdot p(a|x{=}1) \cdot p(b|a) \cdot p(c{=}2|a) \cdot p(d|b,c{=}2,z{=}1) \\ +&p(z{=}1) \cdot p(x{=}2) \cdot p(a|x{=}2) \cdot p(b|a) \cdot p(c{=}1|a) \cdot p(d|b,c{=}1,z{=}1) \\ +&p(z{=}1) \cdot p(x{=}2) \cdot p(a|x{=}2) \cdot p(b|a) \cdot p(c{=}2|a) \cdot p(d|b,c{=}2,z{=}1) \\ +&\vdots \text{ (4 more terms for } z{=}2) +\end{aligned} +$$ + + +## Identifying Redundancy + +Looking at the 8 terms, each subexpression appears multiple times: + +- $p(d|b,c{=}1,z{=}1)$ computed twice (for $x{=}1$ and $x{=}2$) +- $p(d|b,c{=}2,z{=}1)$ computed twice +- $p(d|b,c{=}1,z{=}2)$ computed twice +- $p(d|b,c{=}2,z{=}2)$ computed twice + +Larger suffixes also repeat: + +- $p(b|a) \cdot p(c{=}1|a) \cdot p(d|b,c{=}1,z{=}1)$ appears for both $x{=}1$ and $x{=}2$ +- ... + + +## The Solution: Strategic Caching + +Cache intermediate results to avoid recomputation. + +**Key challenge**: What to use as cache key? 
**Naive approach**: Use all discrete variables seen so far as the key

- Doesn't work: keying on every visited discrete variable is equivalent to full enumeration

**Solution**: Use the minimal set of discrete latent variables that still affects future computation


## The Frontier Concept

$K_k$ is the **minimal** already-visited set whose values appear in any **unvisited** factor; equivalently, it is the separator induced by the evaluation order at step $k$.

This separates past computations from future ones.

Traditional Bayesian network inference exploits conditional independence. We achieve the same effect through runtime caching based on the frontier.


## Frontier Evolution

| Position | Discrete frontier $K_k \cap Q$ | Why it remains needed |
| -------- | ------------------------------ | --------------------------------------------------------------------- |
| Start | $\{\}$ | Nothing visited yet. |
| After $z$ | $\{z\}$ | $z \rightarrow d$ is still ahead; the future depends on $z$. |
| After $x$ | $\{z, x\}$ | $z \rightarrow d$ remains; $x$ still influences the future through $a$. |
| After $c$ | $\{z, c\}$ | Both $z$ and $c$ feed $d$ ahead; $x$ no longer affects the future. |


## Impact of Order

Moving $Z$ late ($x,a,b,c,z,d$) delays introducing $z$ into the frontier, shrinking early cache keys.

Placing $B$ before branching on $C$ avoids recomputing $p(b\mid a)$ for each $(c,z)$.


## Algorithm Overview

Evaluate in a topological order. Enumerate at discrete sites. **Memoize** each suffix with key $(k, K_k\cap Q)$.


## Precomputing Frontier Keys

One backward pass computes all frontiers.

```text
P ← ∅                        // parents of future nodes
for k = n down to 1:         // process v_n, …, v_1
    K_k ← P ∩ {v_1,…,v_k}    // frontier after v_k
    P ← P ∪ parents(v_k)
```

Cache key at discrete sites = $K_k \cap \{\text{discrete variables}\}$

## Example with Order $z, x, a, b, c, d$

Edges: $x \to a$, $a \to b$, $a \to c$, $b \to d$, $c \to d$, $z \to d$

| Position | $P$ (parents of future nodes) | Frontier $K_k$ | Cache key (discrete only) |
|----------|----------------------|----------------|---------------------------|
| After $d$ | $\{\}$ | $\{\}$ | - |
| After $c$ | $\{b, c, z\}$ | $\{z, b, c\}$ | $\{z, c\}$ |
| After $b$ | $\{z, a, b, c\}$ | $\{z, a, b\}$ | (no cache) |
| After $a$ | $\{z, a, b, c\}$ | $\{z, a\}$ | (no cache) |
| After $x$ | $\{z, a, b, c, x\}$ | $\{z, x\}$ | $\{z, x\}$ |
| After $z$ | $\{z, a, b, c, x\}$ | $\{z\}$ | $\{z\}$ |


## Factor Graph Representation

Marginalization can also be formulated as message passing on a factor graph.

```mermaid
graph TD
    X((X)):::discrete
    Z((Z)):::discrete
    C((C)):::discrete
    A{A=a}:::clamped
    B{B=b}:::clamped
    D((D=d)):::observed

    fX[p_x]:::factor
    fZ[p_z]:::factor
    fA[p_ax]:::factor
    fB[p_ba]:::factor
    fC[p_ca]:::factor
    fD[p_dbcz]:::factor

    A --- fA
    A --- fB
    A --- fC
    B --- fB
    B --- fD
    C --- fC
    C --- fD
    D --- fD
    X --- fX
    X --- fA
    Z --- fZ
    Z --- fD

    classDef discrete fill:#FFF4E6,stroke:#D9480F,stroke-width:2px,stroke-dasharray:5
    classDef clamped fill:#B0D4F1,stroke:#1C7ED6,stroke-width:3px
    classDef observed fill:#C3F0DC,stroke:#2B8A3E,stroke-width:3px
    classDef factor fill:#F1F3F5,stroke:#495057,stroke-width:2px
```


## Message Passing with Elimination Order: $X$, then $(C,Z)$

We choose $B$ as the target node toward which all messages flow.

Why this order?
After conditioning on $(A,B,D)$, the graph forms a tree where: + +- $X$ is isolated upstream of $A$ +- $C$ and $Z$ converge at $D$ but are independent given clamped values + +To compute the belief at $B$, we collect all information flowing toward it: + +Step 1: Upstream message (from $X$ through $A$ to $B$) +$$ +\phi_X(a) = \sum_{x} p(x)\,p(a|x) +$$ +This eliminates $X$ and will flow to $B$ via the factor $p(b|a)$. + +After eliminating $X$, we have: + +```mermaid +graph TD + Z((Z)):::discrete + C((C)):::discrete + A{A=a}:::clamped + B{B=b}:::clamped + D((D=d)):::observed + + phiX[φ_X]:::newfactor + fZ[p_z]:::factor + fB[p_ba]:::factor + fC[p_ca]:::factor + fD[p_dbcz]:::factor + + A --- phiX + A --- fB + A --- fC + B --- fB + B --- fD + C --- fC + C --- fD + Z --- fZ + Z --- fD + D --- fD + + classDef discrete fill:#FFF4E6,stroke:#D9480F,stroke-width:2px,stroke-dasharray:5 + classDef clamped fill:#B0D4F1,stroke:#1C7ED6,stroke-width:3px + classDef observed fill:#C3F0DC,stroke:#2B8A3E,stroke-width:3px + classDef factor fill:#D8D8D8,stroke:#495057,stroke-width:2px + classDef newfactor fill:#FFE3E3,stroke:#C92A2A,stroke-width:2px +``` + +Step 2: Downstream message (from $D$ back to $B$) + +To compute what $D$ tells us about $B$, we eliminate $C$ and $Z$: + +- $C$ brings information $p(c|a)$ from its connection to $A$ +- $Z$ brings its prior $p(z)$ +- Both connect to $B$ through factor $p(d|b,c,z)$ + +The combined message to $B$: +$$ +\phi_{CZ}(b) = \sum_{c,z} p(c\mid a)\,p(z)\,p(d\mid b,c,z) +$$ + + +## Final Marginal Likelihood + +At node $B$, we combine: + +- Upstream message: $\phi_X(a)$ through factor $p(b|a)$ +- Local factor: $p(b|a)$ +- Downstream message: $\phi_{CZ}(b)$ from the $D$ branch + +The unnormalized belief at $B$: + +$$ +\boxed{\; +p(a,b,d) = \phi_X(a) \cdot p(b\mid a) \cdot \phi_{CZ}(b) +\;} +$$ + +Expanding each message: +$$ +\begin{aligned} +p(a,b,d) &= \underbrace{\left[\sum_x p(x)p(a|x)\right]}_{\phi_X(a)} \cdot p(b|a) \cdot \underbrace{\left[\sum_{c,z} p(c|a)p(z)p(d|b,c,z)\right]}_{\phi_{CZ}(b)}\\ +&= \left[\sum_x p(x)p(a|x)\right] \cdot p(b|a) \cdot \left[\sum_c p(c|a) \sum_z p(z)p(d|b,c,z)\right] +\end{aligned} +$$ + + +## Which Topological Orders Yield This Factorization? 
Evaluating the Bayesian network in the order $X, A, B, C, Z, D$ produces:

$$
\left[\sum_x p(x)p(a|x)\right] \cdot p(b|a) \cdot \left[\sum_c p(c|a) \sum_z p(z)p(d|b,c,z)\right]
$$


## Frontier Evolution for Order $X, A, B, C, Z, D$

| Position | Discrete frontier | What happens |
|----------|------------------|--------------|
| Start | $\{\}$ | Nothing processed yet |
| After $X$ | $\{X\}$ | $X$ still feeds the unvisited factor $p(a \mid x)$ |
| After $A$ | $\{\}$ | $A$ is continuous; $X$ is summed out here |
| After $B$ | $\{\}$ | $B$ is continuous |
| After $C$ | $\{C\}$ | $C$ enters the frontier, needed for $D$ |
| After $Z$ | $\{C,Z\}$ | Both needed for $D$ |
| After $D$ | $\{\}$ | Everything processed |


## Messages ARE the DP Cache Entries

**Message passing computes the same values as DP enumeration.**

In DP enumeration with caching:

- We cache partial sums with key = discrete frontier
- The frontier tells us which discrete variables affect future computation

In message passing:

- Messages ARE these cached partial sums
- Message indexing matches the frontier

**Example from our graph:**

| Computation | What it is | Frontier (cache key) | Reuse pattern |
|------------|------------|---------------------|---------------|
| $\sum_x p(x)p(a \mid x)$ | Message eliminating $X$ | $\{\}$ | Computed once |
| $\sum_z p(z)p(d \mid b,c,z)$ | Partial message eliminating $Z$ | $\{C\}$ | One per $c$ value |

The frontier-based cache key ensures we recompute only when necessary.


## Two Perspectives on the Same Algorithm

### Top-down with memoization (our frontier caching):

- Start from the goal: compute $p(a,b,d)$
- Recursively evaluate the DAG following topological order
- **Memoize** (cache) partial sums at discrete variables
- Cache key = discrete frontier (what future computation needs)
- Like recursive Fibonacci with memoization

### Bottom-up tabulation (factor graph message passing):

- Start from the leaves of the tree
- Build up messages from smaller to larger scopes
- **Precompute** all messages in optimal order
- Store messages indexed by their variables
- Like iterative Fibonacci filling a table


## Analogy: Two Ways to Compute Fibonacci

**Top-down with memoization:**
```julia
memo = Dict{Int,Int}()
function fib_memo(n)
    n ∈ keys(memo) && return memo[n]               # check cache
    n ≤ 1 && return n                              # base cases
    memo[n] = fib_memo(n - 1) + fib_memo(n - 2)    # cache result
    return memo[n]
end
```

**Bottom-up tabulation:**
```julia
function fib_table(n)
    n ≤ 1 && return n                 # guard: dp[2] below requires n ≥ 1
    dp = zeros(Int, n + 1)
    dp[1] = 0; dp[2] = 1
    for i in 3:(n + 1)
        dp[i] = dp[i-1] + dp[i-2]     # fill table
    end
    return dp[n + 1]
end
```

---

Ultimately, the caching algorithm delivers factor-graph-quality marginalization while keeping the DAG representation, which has several benefits:

- A DAG plus a topological order compiles to a straight-line program
- Deterministic nodes make computation reuse easy to reason about
- It is a natural compilation target; the evaluator is straightforward to write and understand

The contribution is to show what the cache key is and how to precompute it.
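To make the frontier machinery concrete, here is a minimal, self-contained sketch in plain Julia (illustrative only; `compute_frontiers` is not JuliaBUGS API) that runs the backward pass on the example graph and reproduces the frontier table for the order $z, x, a, b, c, d$:

```julia
# Backward pass from "Precomputing Frontier Keys", on the example graph.
# Edges: x→a, a→b, a→c, b→d, c→d, z→d
parents = Dict(:z => Symbol[], :x => Symbol[], :a => [:x],
               :b => [:a], :c => [:a], :d => [:b, :c, :z])
order = [:z, :x, :a, :b, :c, :d]

function compute_frontiers(order, parents)
    P = Set{Symbol}()                            # parents of future (unvisited) nodes
    K = Vector{Set{Symbol}}(undef, length(order))
    for k in length(order):-1:1                  # process v_n, …, v_1
        K[k] = intersect(P, Set(order[1:k]))     # frontier after order[k]
        union!(P, parents[order[k]])
    end
    return K
end

discrete = Set([:z, :x, :c])
for (k, Kk) in enumerate(compute_frontiers(order, parents))
    key = intersect(Kk, discrete)                # cache key = discrete part of frontier
    println("after ", order[k], ": frontier = ", sort!(collect(Kk)),
            ", cache key = ", sort!(collect(key)))
end
```

Running it prints `after c: frontier = [:b, :c, :z], cache key = [:c, :z]` and so on, matching the example table above.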
"336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" +LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" diff --git a/JuliaBUGS/experiments/README.md b/JuliaBUGS/experiments/README.md new file mode 100644 index 000000000..54b00cff6 --- /dev/null +++ b/JuliaBUGS/experiments/README.md @@ -0,0 +1,102 @@ +# Experiments Workspace + +- Project path: `JuliaBUGS/experiments` +- Scripts live in `JuliaBUGS/experiments/scripts/` +- Shared helpers are in `JuliaBUGS/experiments/utils.jl` + +## Running scripts + +Always pass the experiments project to Julia so the correct environment loads: + +``` +julia --project=JuliaBUGS/experiments scripts/hmm_marginal_logp.jl +``` + +Most scripts accept environment variables to tweak configurations. The simple +HMM example supports: + +- `AM_SEED` (default `1`) – RNG seed used for synthetic data. +- `AM_T` (default `50`) – length of the simulated sequence. + +Batch sweep (`hmm_correctness_sweep.jl`): + +- `AM_SWEEP_SEEDS` (default `1`) – comma-separated list of seeds. +- `AM_SWEEP_K` (default `2,4`) – comma-separated list of state counts. +- `AM_SWEEP_T` (default `50,200`) – comma-separated list of sequence lengths. + +GMM sweep (`gmm_correctness_sweep.jl`): + +- `AG_SWEEP_SEEDS` (default `1`) – comma-separated list of seeds. +- `AG_SWEEP_K` (default `2,4`) – comma-separated list of mixture counts. +- `AG_SWEEP_N` (default `100,1000`) – comma-separated list of observation counts. + +HMM gradient check (`hmm_gradient_check.jl`): + +- `AGC_SEED` (default `1`) – RNG seed for synthetic data. +- `AGC_K` (default `2`) – number of HMM states. +- `AGC_T` (default `50`) – length of the simulated sequence. +- `AGC_EPS` (default `1e-5`) – step size for central finite differences. +- `AGC_VERBOSE` (default `0`) – set `1` to print per-θ details. +- `AGC_SWEEP_SEEDS` – comma-separated list of seeds (overrides `AGC_SEED`). +- `AGC_SWEEP_K` – comma-separated list of state counts (overrides `AGC_K`). +- `AGC_SWEEP_T` – comma-separated list of sequence lengths (overrides `AGC_T`). + +GMM gradient check (`gmm_gradient_check.jl`): + +- `AGG_SEED` (default `1`) – RNG seed for synthetic data. +- `AGG_K` (default `2`) – number of mixture components. +- `AGG_N` (default `200`) – number of observations. +- `AGG_EPS` (default `1e-5`) – step size for central finite differences. +- `AGG_VERBOSE` (default `0`) – set `1` to print per-θ details. +- `AGG_SWEEP_SEEDS` – comma-separated list of seeds (overrides `AGG_SEED`). +- `AGG_SWEEP_K` – comma-separated list of mixture counts (overrides `AGG_K`). +- `AGG_SWEEP_N` – comma-separated list of observation counts (overrides `AGG_N`). 
HMM scaling benchmark (`hmm_scaling_bench.jl`):

- `AS_SEED` (default `1`) – RNG seed. Use `AS_SWEEP_SEEDS` for a list.
- `AS_K` (default `2,4`) – number of states. Use `AS_SWEEP_K` for a list.
- `AS_T` (default `50,200`) – sequence length. Use `AS_SWEEP_T` for a list.
- `AS_TRIALS` (default `5`) – number of timing repetitions per case.

Notes:
- The benchmark enforces the interleaved (time-first) order to reflect optimal scaling.
- Output: CSV lines with columns
  `seed,K,T,trials,min_time_sec,logp,max_frontier,mean_frontier,sum_frontier`.

FHMM order comparison (`fhmm_order_comparison.jl`):

- `AFH_SEED` (default `1`) – RNG seed.
- `AFH_C` (default `2`) – number of chains.
- `AFH_K` (default `4`) – number of states per chain.
- `AFH_T` (default `100`) – length of the sequence.
- `AFH_TRIALS` (default `10`) – timing samples per order.
- `AFH_MODE` (default `frontier`) – `frontier` or `timed`. Interleaved is always timed; the bad order is timed only if its proxy cost ≤ `AFH_COST_THRESH` (or when `AFH_MODE=timed`).
- `AFH_COST_THRESH` (default `1e8`) – threshold on the proxy Σ K^width (compared in log-space) to avoid intractable timings.
- `AFH_ORDERS` (optional) – comma‑separated list of orders to run. Accepted values: `interleaved`, `states_then_y`. Default runs both in that order. Example: `AFH_ORDERS=interleaved` or `AFH_ORDERS=states_then_y`.

Outputs CSV lines with columns
`order,max_frontier,mean_frontier,sum_frontier,log_cost_proxy,min_time_sec,logp`.
Two consistent orders are evaluated:

- `interleaved` (time‑first, tractable)
- `states_then_y` (all z's, then all y's; typically intractable for moderate T)

HMT order comparison (`hmt_order_comparison.jl`):

- `AHMT_SEED` (default `1`) – RNG seed.
- `AHMT_B` (default `2`) – branching factor.
- `AHMT_DEPTH` (default `8`) – tree depth.
- `AHMT_K` (default `4`) – number of states per node.
- `AHMT_TRIALS` (default `10`) – timing samples when timing is enabled.
- `AHMT_MODE` (default `frontier`) – `frontier` (frontier/proxy only), `timed` (time all listed orders), or `dfs` (time DFS only). We avoid timing BFS when its proxy is enormous.

Reproduce figures (PDFs in `experiments/figures`):

```
julia --project=JuliaBUGS/experiments experiments/plotting/make_figures.jl
```

Tables included in the draft live in `experiments/tables/` and are generated from CSV outputs under `experiments/results/`.

When adding new scripts, document their environment variables near the top of the file and list them here for quick reference.
diff --git a/JuliaBUGS/experiments/experiment_plan.md b/JuliaBUGS/experiments/experiment_plan.md
new file mode 100644
index 000000000..20822625b
--- /dev/null
+++ b/JuliaBUGS/experiments/experiment_plan.md
@@ -0,0 +1,109 @@
# Experiment Plan: Auto-Marginalization

Experiments validating automatic marginalization of discrete latent variables in JuliaBUGS.

## 1. Correctness

Validates marginalized log-probability against analytical references.
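Concretely, the references are exact: the HMM scripts compare against the forward algorithm (`forward_logp` in the scripts), and the GMM scripts compare against the closed-form mixture log-likelihood Σ_i logsumexp_k(log w_k + logpdf(Normal(μ_k, σ_k), y_i)) (`closed_form_logp`). Each CSV row reports the difference between the auto-marginalized log-density and the reference.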
+ +```bash +# HMM +AM_SWEEP_SEEDS=1,2,3 AM_SWEEP_K=2,4,8,16 AM_SWEEP_T=50,100,200,400 \ + julia --project=JuliaBUGS/experiments scripts/hmm_correctness_sweep.jl + +# GMM +AG_SWEEP_SEEDS=1,2,3 AG_SWEEP_K=2,4,8 AG_SWEEP_N=100,500,1000,5000 \ + julia --project=JuliaBUGS/experiments scripts/gmm_correctness_sweep.jl + +# HDP-HMM (sticky, κ=0) +AHDPC_SEEDS=1,2 AHDPC_K=5,10,20 AHDPC_T=50,100,200,400 AHDPC_KAPPA=0.0 \ + julia --project=JuliaBUGS/experiments scripts/hdphmm_correctness.jl + +# HDP-HMM (sticky, κ=5) +AHDPC_SEEDS=1,2 AHDPC_K=5,10,20 AHDPC_T=50,100,200,400 AHDPC_KAPPA=5.0 \ + julia --project=JuliaBUGS/experiments scripts/hdphmm_correctness.jl +``` + +## 2. Gradients + +Validates automatic differentiation against finite differences. + +```bash +# HMM +AGC_SWEEP_SEEDS=1,2,3 AGC_SWEEP_K=2,4,8 AGC_SWEEP_T=50,100,200 \ + julia --project=JuliaBUGS/experiments scripts/hmm_gradient_check.jl + +# GMM +AGG_SWEEP_SEEDS=1,2,3 AGG_SWEEP_K=2,4,8 AGG_SWEEP_N=200,500,1000 \ + julia --project=JuliaBUGS/experiments scripts/gmm_gradient_check.jl + +# HDP-HMM (sticky, κ=0) +AHDPG_SWEEP_SEEDS=1,2 AHDPG_SWEEP_K=5,10,20 AHDPG_SWEEP_T=100,200 AHDPG_KAPPA=0.0 \ + julia --project=JuliaBUGS/experiments scripts/hdphmm_gradient_check.jl + +# HDP-HMM (sticky, κ=5) +AHDPG_SWEEP_SEEDS=1,2 AHDPG_SWEEP_K=5,10,20 AHDPG_SWEEP_T=100,200 AHDPG_KAPPA=5.0 \ + julia --project=JuliaBUGS/experiments scripts/hdphmm_gradient_check.jl +``` + +## 3. Scaling + +Benchmarks runtime vs problem size to verify O(T·K²) complexity beyond overhead regime. + +```bash +# HMM - Push into asymptotic regime +AS_SWEEP_K=8,16,32,64,128,256,512,1024,2048 \ +AS_SWEEP_T=50,100,200,400,800,1600,3200,6400 \ +AS_TRIALS=10 \ + julia --project=JuliaBUGS/experiments scripts/hmm_scaling_bench.jl +``` + +## 4. Variable Ordering: FHMM + +Compares elimination orders (interleaved, states_then_y, min_fill, min_degree). + +```bash +# Small configs with timing +AFH_C=2 AFH_K=2 AFH_T=5 AFH_MODE=timed AFH_ORDERS=interleaved,states_then_y \ + julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl +AFH_C=2 AFH_K=4 AFH_T=10 AFH_MODE=timed AFH_ORDERS=interleaved,states_then_y \ + julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl + +# Larger configs (frontier only) +AFH_C=2 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \ + julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl +AFH_C=3 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \ + julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl +AFH_C=4 AFH_K=4 AFH_T=50 AFH_MODE=frontier AFH_ORDERS=interleaved,states_then_y,min_fill,min_degree \ + julia --project=JuliaBUGS/experiments scripts/fhmm_order_comparison.jl +``` + +## 5. Variable Ordering: HMT + +Compares tree traversal orders (dfs, bfs, random_dfs, min_fill, min_degree). 
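Rough intuition for why the traversal matters (assuming the frontier semantics described in the docs): a DFS order finishes each subtree before starting the next, so the discrete frontier stays on the order of the tree depth, whereas a BFS order keeps a whole level of states alive at once, so the frontier and the Σ K^width proxy grow with the level size B^depth, which is why the script avoids timing BFS when its proxy is enormous.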
+ +```bash +# Varying depth +AHMT_B=2 AHMT_K=4 AHMT_DEPTH=4 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl +AHMT_B=2 AHMT_K=4 AHMT_DEPTH=6 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl +AHMT_B=2 AHMT_K=4 AHMT_DEPTH=8 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl +AHMT_B=2 AHMT_K=4 AHMT_DEPTH=10 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl + +# Varying branching and states +AHMT_B=2 AHMT_K=2 AHMT_DEPTH=6 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl +AHMT_B=3 AHMT_K=2 AHMT_DEPTH=6 AHMT_MODE=frontier \ + julia --project=JuliaBUGS/experiments scripts/hmt_order_comparison.jl +``` + +## Notes + +- **Complexity**: With good (interleaved) ordering, HMM marginalization achieves **O(T·K²)** complexity (linear in T). Bad orderings (e.g., states-first) explode to O(K^T) by enumerating all state sequences. Frontier width measures active discrete variables: good orders keep it ≈ O(1), bad orders reach O(T). +- **Heuristics**: Min-fill and min-degree with randomized tie-breaking (3 restarts) find good orders for arbitrary graphical models. +- **HDP-HMM**: Both correctness and gradient scripts use the sticky HDP-HMM formulation with kappa (κ) parameter. Set AHDPC_KAPPA/AHDPG_KAPPA to control sticky self-transition bias. κ=0 is standard HDP-HMM, κ>0 adds self-transition preference. +- **Scaling**: Sweep includes K up to 2048 states and T up to 6400 time steps to clearly show asymptotic O(T·K²) complexity beyond JIT/overhead regime. On log-log plot (time vs T, fixed K), expect slope=1 at large T. Parallel lines vertically shifted by ≈2·log(K) demonstrate tractability even with huge state spaces. +- **Output**: All scripts write CSV to stdout. 
Redirect as needed: `> results/output.csv` \ No newline at end of file diff --git a/JuliaBUGS/experiments/scripts/fhmm_order_comparison.jl b/JuliaBUGS/experiments/scripts/fhmm_order_comparison.jl new file mode 100644 index 000000000..d6868b037 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/fhmm_order_comparison.jl @@ -0,0 +1,167 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using Statistics +using BenchmarkTools + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive Categorical Normal +using LogDensityProblems + +# Env configuration +C = try parse(Int, get(ENV, "AFH_C", "2")) catch; 2 end +K = try parse(Int, get(ENV, "AFH_K", "4")) catch; 4 end +T = try parse(Int, get(ENV, "AFH_T", "100")) catch; 100 end +seed = try parse(Int, get(ENV, "AFH_SEED", "1")) catch; 1 end +trials = try parse(Int, get(ENV, "AFH_TRIALS", "10")) catch; 10 end +mode = lowercase(get(ENV, "AFH_MODE", "frontier")) # frontier | timed +cost_thresh = try parse(Float64, get(ENV, "AFH_COST_THRESH", "1.0e8")) catch; 1.0e8 end + +function default_fhmm_params(C, K) + init_probs = fill(1.0 / K, C, K) + diag = 0.9 + off = (1.0 - diag) / (K - 1) + transition = Array{Float64}(undef, C, K, K) + for c in 1:C, i in 1:K, j in 1:K + transition[c, i, j] = off + end + for c in 1:C, k in 1:K + transition[c, k, k] = diag + end + mu = Array{Float64}(undef, C, K) + for c in 1:C + mu[c, :] = collect(range(-1.5, 1.5; length=K)) + end + sigma_y = 0.7 + return init_probs, transition, mu, sigma_y +end + +function simulate_fhmm(rng::AbstractRNG, C::Int, K::Int, T::Int; init_probs, transition, mu, sigma_y) + z = Array{Int}(undef, C, T) + y = Vector{Float64}(undef, T) + # Initial states + for c in 1:C + z[c, 1] = rand(rng, Categorical(Vector(view(init_probs, c, 1:K)))) + end + y[1] = rand(rng, Normal(sum(mu[c, z[c, 1]] for c in 1:C), sigma_y)) + # Transitions + for t in 2:T + for c in 1:C + z[c, t] = rand(rng, Categorical(Vector(view(transition, c, z[c, t - 1], 1:K)))) + end + y[t] = rand(rng, Normal(sum(mu[c, z[c, t]] for c in 1:C), sigma_y)) + end + return (; z, y) +end + +function fhmm_model() + @bugs begin + # Initial states + for c in 1:C + z[c, 1] ~ Categorical(init_probs[c, 1:K]) + end + # Emission mean at t=1 via cumulative sum over chains + s[1, 1] = mu[1, z[1, 1]] + for c in 2:C + s[c, 1] = s[c - 1, 1] + mu[c, z[c, 1]] + end + y[1] ~ Normal(s[C, 1], sigma_y) + + # Transitions and emissions for t >= 2 + for t in 2:T + for c in 1:C + z[c, t] ~ Categorical(transition[c, z[c, t - 1], 1:K]) + end + s[1, t] = mu[1, z[1, t]] + for c in 2:C + s[c, t] = s[c - 1, t] + mu[c, z[c, t]] + end + y[t] ~ Normal(s[C, t], sigma_y) + end + end +end + +function frontier_stats_for(model) + gd = model.graph_evaluation_data + order = gd.marginalization_order + keys = gd.minimal_cache_keys + widths = [length(get(keys, idx, Int[])) for idx in order] + if isempty(widths) + return 0, 0.0, 0 + end + return maximum(widths), mean(widths), sum(widths) +end + +# Simulate data +rng = MersenneTwister(seed) +init_probs, transition, mu, sigma_y = default_fhmm_params(C, K) +sim = simulate_fhmm(rng, C, K, T; init_probs=init_probs, transition=transition, mu=mu, sigma_y=sigma_y) + +# Build and compile model +model_def = fhmm_model() +data = ( + C = C, + K = K, + T = T, + y = sim.y, + init_probs = init_probs, + transition = transition, + mu = mu, + sigma_y = sigma_y, +) +model, θ0 = compile_autmarg(model_def, data) + +# Define all available orders +orders = Dict{String,Function}( + 
"interleaved" => () -> make_model_with_order(model, build_fhmm_interleaved_order(model)), + "states_then_y" => () -> make_model_with_order(model, build_fhmm_states_then_emissions_order(model)), + "min_fill" => () -> make_model_with_order(model, build_min_fill_order(model; rng=MersenneTwister(seed+2), num_restarts=3)), + "min_degree" => () -> make_model_with_order(model, build_min_degree_order(model; rng=MersenneTwister(seed+3), num_restarts=3)), +) + +function parse_orders_env(default_list) + s = lowercase(strip(get(ENV, "AFH_ORDERS", ""))) + if isempty(s) + return default_list + end + names = [strip(x) for x in split(s, ',') if !isempty(strip(x))] + # Filter to known orders; keep input order + selected = [nm for nm in names if haskey(orders, nm)] + return isempty(selected) ? default_list : selected +end + +selected_orders = parse_orders_env(["interleaved", "states_then_y"]) + +@printf "# FHMM order comparison (C=%d, K=%d, T=%d)\n" C K T +@printf "# order,max_frontier,mean_frontier,sum_frontier,log_cost_proxy,min_time_sec,logp\n" +logp_ref = nothing +for name in selected_orders + buildfun = orders[name] + m2 = buildfun() + # Frontier stats and cost proxy (use K as hint) + max_f, mean_f, sum_f, logproxy = frontier_cost_proxy(m2; K_hint=K) + # Timing policy: always time interleaved; only time bad order if proxy ≤ threshold or AFH_MODE=timed + do_time = (name == "interleaved") || (mode == "timed") + tmin = NaN + logp = NaN + # Compare exp(logproxy) with threshold without overflow by comparing logs + if do_time && (name == "interleaved" || logproxy <= log(cost_thresh)) + _ = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0) + _ = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0) + tmin = @belapsed Base.invokelatest(LogDensityProblems.logdensity, $m2, $θ0) samples=trials evals=1 + logp = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0) + if logp_ref === nothing + global logp_ref = logp + end + end + @printf "%s,%d,%.3f,%d,%.3e,%s,%s\n" name max_f mean_f sum_f logproxy ( + isnan(tmin) ? "NA" : @sprintf("%.6e", tmin) + ) ( + isnan(logp) ? 
"NA" : @sprintf("%.12f", logp) + ) +end diff --git a/JuliaBUGS/experiments/scripts/gmm_correctness_sweep.jl b/JuliaBUGS/experiments/scripts/gmm_correctness_sweep.jl new file mode 100644 index 000000000..1be97ef9d --- /dev/null +++ b/JuliaBUGS/experiments/scripts/gmm_correctness_sweep.jl @@ -0,0 +1,90 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive Categorical Normal +using LogDensityProblems +using LogExpFunctions: logsumexp + +function simulate_gmm(rng::AbstractRNG, N::Int, K::Int; weights, means, sigmas) + z = Vector{Int}(undef, N) + y = Vector{Float64}(undef, N) + for i in 1:N + z[i] = rand(rng, Categorical(weights)) + y[i] = rand(rng, Normal(means[z[i]], sigmas[z[i]])) + end + return (; z, y) +end + +function closed_form_logp(y, weights, means, sigmas) + N = length(y) + K = length(weights) + log_weights = log.(weights) + logvals = zeros(N) + for i in 1:N + comps = similar(log_weights) + for k in 1:K + comps[k] = log_weights[k] + logpdf(Normal(means[k], sigmas[k]), y[i]) + end + logvals[i] = logsumexp(comps) + end + return sum(logvals) +end + +function build_gmm_model(N, K) + @bugs begin + for i in 1:N + z[i] ~ Categorical(weights) + y[i] ~ Normal(means[z[i]], sigmas[z[i]]) + end + end +end + +function default_params(K) + weights = fill(1.0 / K, K) + means = collect(range(-2.0, 2.0; length=K)) + sigmas = fill(0.9, K) + return weights, means, sigmas +end + +function run_case(; K::Int, N::Int, seed::Int) + rng = MersenneTwister(seed) + weights, means, sigmas = default_params(K) + sim = simulate_gmm(rng, N, K; weights=weights, means=means, sigmas=sigmas) + + model_def = build_gmm_model(N, K) + data = ( + N = N, + K = K, + y = sim.y, + weights = weights, + means = means, + sigmas = sigmas, + ) + + model, θ0 = compile_autmarg(model_def, data) + logp = Base.invokelatest(LogDensityProblems.logdensity, model, θ0) + logp_ref = closed_form_logp(sim.y, weights, means, sigmas) + return logp, logp_ref +end + +seed_str = strip(get(ENV, "AG_SWEEP_SEEDS", "1")) +seed_vals = isempty(seed_str) ? [1] : parse.(Int, split(seed_str, ',')) +K_str = strip(get(ENV, "AG_SWEEP_K", "2,4")) +Ks = parse.(Int, split(K_str, ',')) +N_str = strip(get(ENV, "AG_SWEEP_N", "100,1000")) +Ns = parse.(Int, split(N_str, ',')) + +@printf "# GMM correctness sweep\n" +@printf "# seed,K,N,logp_autmarg,logp_closed_form,diff\n" +for seed in seed_vals, K in Ks, N in Ns + logp, logp_ref = run_case(K=K, N=N, seed=seed) + diff = logp - logp_ref + @printf "%d,%d,%d,%.12f,%.12f,%.3e\n" seed K N logp logp_ref diff +end diff --git a/JuliaBUGS/experiments/scripts/gmm_gradient_check.jl b/JuliaBUGS/experiments/scripts/gmm_gradient_check.jl new file mode 100644 index 000000000..446111f65 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/gmm_gradient_check.jl @@ -0,0 +1,154 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using ForwardDiff + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs, getparams +JuliaBUGS.@bugs_primitive Categorical Normal exp +using LogDensityProblems + +parse_list(str) = begin + s = strip(str) + isempty(s) && return Int[] + if occursin(',', s) + return parse.(Int, split(s, ",")) + else + return [parse(Int, s)] + end +end + +seeds = let v = get(ENV, "AGG_SWEEP_SEEDS", get(ENV, "AGG_SEED", "1")) + xs = parse_list(v) + isempty(xs) ? 
[1] : xs +end +Ks = let v = get(ENV, "AGG_SWEEP_K", get(ENV, "AGG_K", "2")) + xs = parse_list(v) + isempty(xs) ? [2] : xs +end +Ns = let v = get(ENV, "AGG_SWEEP_N", get(ENV, "AGG_N", "200")) + xs = parse_list(v) + isempty(xs) ? [200] : xs +end + +eps = try parse(Float64, get(ENV, "AGG_EPS", "1e-5")) catch; 1e-5 end +verbose = get(ENV, "AGG_VERBOSE", "0") == "1" + +function simulate_gmm(rng, N, K; weights, means, sigmas) + z = Vector{Int}(undef, N) + y = Vector{Float64}(undef, N) + for i in 1:N + z[i] = rand(rng, Categorical(weights)) + y[i] = rand(rng, Normal(means[z[i]], sigmas[z[i]])) + end + return (; z, y) +end + +function default_gmm_params(K) + weights = fill(1.0 / K, K) + means_true = collect(range(-1.5, 1.5; length=K)) + sigmas_true = fill(0.7, K) + return weights, means_true, sigmas_true +end + +function priors_from_truth(means_true, sigmas_true) + mu_prior_mean = means_true + mu_prior_std = 1.0 + logsigma_prior_mean = log.(sigmas_true) + logsigma_prior_std = 0.3 + return mu_prior_mean, mu_prior_std, logsigma_prior_mean, logsigma_prior_std +end + +function gmm_param_model() + @bugs begin + for k in 1:K + mu[k] ~ Normal(mu_prior_mean[k], mu_prior_std) + log_sigma[k] ~ Normal(logsigma_prior_mean[k], logsigma_prior_std) + sigma[k] = exp(log_sigma[k]) + end + + for i in 1:N + z[i] ~ Categorical(weights) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end +end + +function finite_difference(f, x; ϵ=1e-5) + g = similar(x) + fx = f(x) + for i in eachindex(x) + xi = x[i] + x[i] = xi + ϵ + fp = f(x) + x[i] = xi - ϵ + fm = f(x) + x[i] = xi + g[i] = (fp - fm) / (2ϵ) + end + return g, fx +end + +function run_case(seed::Int, K::Int, N::Int; ϵ::Float64=1e-5, verbose::Bool=false) + rng = MersenneTwister(seed) + weights, means_true, sigmas_true = default_gmm_params(K) + sim = simulate_gmm(rng, N, K; weights=weights, means=means_true, sigmas=sigmas_true) + + mu_prior_mean, mu_prior_std, logsigma_prior_mean, logsigma_prior_std = priors_from_truth(means_true, sigmas_true) + + data = ( + N = N, + K = K, + y = sim.y, + weights = weights, + mu_prior_mean = mu_prior_mean, + mu_prior_std = mu_prior_std, + logsigma_prior_mean = logsigma_prior_mean, + logsigma_prior_std = logsigma_prior_std, + ) + + model, θ0 = compile_autmarg(gmm_param_model(), data) + target(θ) = Base.invokelatest(LogDensityProblems.logdensity, model, θ) + autograd = ForwardDiff.gradient(target, θ0) + fdgrad, logp = finite_difference(target, copy(θ0); ϵ=ϵ) + + diffs = autograd .- fdgrad + max_abs_diff = maximum(abs, diffs) + denom = map(i -> max(max(abs(autograd[i]), abs(fdgrad[i])), 1e-12), eachindex(θ0)) + rel_diffs = abs.(diffs) ./ denom + max_rel_diff = maximum(rel_diffs) + + if verbose + @printf "# GMM gradient check\n" + @printf "seed=%d, K=%d, N=%d\n" seed K N + @printf "logp = %.12f\n" logp + for i in eachindex(θ0) + @printf "θ[%d]: autodiff=%.6e fd=%.6e diff=%.2e rel=%.2e\n" i autograd[i] fdgrad[i] diffs[i] rel_diffs[i] + end + end + + return logp, max_abs_diff, max_rel_diff +end + +is_single = (length(seeds) == 1) && (length(Ks) == 1) && (length(Ns) == 1) + +if !is_single + @printf "# GMM gradient sweep\n" + @printf "# seed,K,N,max_abs_diff,max_rel_diff,logp\n" +end + +for seed in seeds, K in Ks, N in Ns + logp, max_abs_diff, max_rel_diff = run_case(seed, K, N; ϵ=eps, verbose=(verbose && is_single)) + if is_single && !verbose + @printf "# GMM gradient check\n" + @printf "seed=%d, K=%d, N=%d\n" seed K N + @printf "logp = %.12f\n" logp + @printf "max_abs_diff=%.3e, max_rel_diff=%.3e\n" max_abs_diff max_rel_diff + elseif !is_single 
+ @printf "%d,%d,%d,%.3e,%.3e,%.12f\n" seed K N max_abs_diff max_rel_diff logp + end +end diff --git a/JuliaBUGS/experiments/scripts/hdphmm_correctness.jl b/JuliaBUGS/experiments/scripts/hdphmm_correctness.jl new file mode 100644 index 000000000..5e1e1dcdc --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hdphmm_correctness.jl @@ -0,0 +1,166 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs + +# Deterministic helper to add kappa mass to one index of a vector +function diagshift(alpha_beta::AbstractVector{T}, i::Integer, kappa::Real) where {T} + K = length(alpha_beta) + out = Vector{T}(undef, K) + @inbounds for j in 1:K + out[j] = alpha_beta[j] + end + if 1 <= i <= K + @inbounds out[i] = out[i] + kappa + end + return out +end + +JuliaBUGS.@bugs_primitive Categorical Normal Beta Dirichlet diagshift +using LogDensityProblems +using LogExpFunctions: logsumexp + +parse_list(str) = begin + s = strip(str) + isempty(s) && return Int[] + if occursin(',', s) + return parse.(Int, split(s, ",")) + else + return [parse(Int, s)] + end +end + +seeds = let v = get(ENV, "AHDPC_SEEDS", "1") + xs = parse_list(v) + isempty(xs) ? [1] : xs +end +Ks = let v = get(ENV, "AHDPC_K", "5,10") + xs = parse_list(v) + isempty(xs) ? [5, 10] : xs +end +Ts = let v = get(ENV, "AHDPC_T", "100,200") + xs = parse_list(v) + isempty(xs) ? [100, 200] : xs +end + +alpha = try parse(Float64, get(ENV, "AHDPC_ALPHA", "5.0")) catch; 5.0 end +gamma = try parse(Float64, get(ENV, "AHDPC_GAMMA", "1.0")) catch; 1.0 end +kappa = try parse(Float64, get(ENV, "AHDPC_KAPPA", "0.0")) catch; 0.0 end + +function stick_break(v::AbstractVector) + K = length(v) + 1 + beta = similar(v, Float64, K) + stick = 1.0 + for k in 1:K-1 + beta[k] = v[k] * stick + stick *= (1 - v[k]) + end + beta[K] = stick + return beta +end + +function simulate_hdp_params(rng::AbstractRNG, K::Int; α::Real, γ::Real, κ::Real) + v = [rand(rng, Beta(1.0, γ)) for _ in 1:K-1] + beta = stick_break(v) + pi = Array{Float64}(undef, K, K) + for i in 1:K + a = α .* beta + a[i] += κ + pi[i, :] = rand(rng, Dirichlet(a)) + end + return beta, pi +end + +function simulate_hmm(rng::AbstractRNG, T::Int, K::Int; rho, pi, means, sigmas) + z = Vector{Int}(undef, T) + y = Vector{Float64}(undef, T) + z[1] = rand(rng, Categorical(rho)) + y[1] = rand(rng, Normal(means[z[1]], sigmas[z[1]])) + for t in 2:T + z[t] = rand(rng, Categorical(pi[z[t - 1], :])) + y[t] = rand(rng, Normal(means[z[t]], sigmas[z[t]])) + end + return (; z, y) +end + +function default_emission_params(K) + means = collect(range(-1.5, 1.5; length=K)) + sigmas = fill(0.6, K) + return means, sigmas +end + +function forward_logp(obs, rho, pi, means, sigmas) + T = length(obs) + K = length(rho) + log_emissions = Array{Float64}(undef, T, K) + for t in 1:T, k in 1:K + log_emissions[t, k] = logpdf(Normal(means[k], sigmas[k]), obs[t]) + end + log_pi = log.(pi) + log_alpha = Vector{Float64}(undef, K) + for k in 1:K + log_alpha[k] = log(rho[k]) + log_emissions[1, k] + end + tmp = similar(log_alpha) + for t in 2:T + for k in 1:K + tmp[k] = log_emissions[t, k] + logsumexp(log_alpha .+ log_pi[:, k]) + end + log_alpha, tmp = tmp, log_alpha + end + return logsumexp(log_alpha) +end + +function build_model() + @bugs begin + # Simple HMM model with pre-computed sticky HDP-HMM parameters + # The init_probs and transition matrices are passed as data (already include sticky kappa) + z[1] ~ Categorical(init_probs) + y[1] ~ Normal(means[z[1]], 
sigmas[z[1]]) + for t in 2:T + z[t] ~ Categorical(transition[z[t - 1], 1:K]) + y[t] ~ Normal(means[z[t]], sigmas[z[t]]) + end + end +end + +function run_case(; K::Int, T::Int, seed::Int, α::Float64, γ::Float64, κ::Float64) + rng = MersenneTwister(seed) + + # Generate parameters using sticky HDP-HMM structure + beta, pi = simulate_hdp_params(rng, K; α=α, γ=γ, κ=κ) + means, sigmas = default_emission_params(K) + + # Simulate data from sticky HDP-HMM + sim = simulate_hmm(rng, T, K; rho=beta, pi=pi, means=means, sigmas=sigmas) + + model_def = build_model() + data = ( + T = T, + K = K, + y = sim.y, + init_probs = beta, # Initial state distribution from stick-breaking + transition = pi, # Transition matrices with sticky kappa + means = means, + sigmas = sigmas, + ) + model, θ0 = compile_autmarg(model_def, data) + logp = Base.invokelatest(LogDensityProblems.logdensity, model, θ0) + logp_ref = forward_logp(sim.y, beta, pi, means, sigmas) + return logp, logp_ref +end + +@printf "# HDP-HMM correctness (sticky κ=%.2f, data simulated from sticky HDP-HMM)\n" kappa +@printf "# seed,K,T,alpha,gamma,kappa,logp_autmarg,logp_forward,diff\n" +for seed in seeds, K in Ks, T in Ts + logp, logp_ref = run_case(K=K, T=T, seed=seed, α=alpha, γ=gamma, κ=kappa) + diff = logp - logp_ref + @printf "%d,%d,%d,%.3f,%.3f,%.3f,%.12f,%.12f,%.3e\n" seed K T alpha gamma kappa logp logp_ref diff +end + diff --git a/JuliaBUGS/experiments/scripts/hdphmm_gradient_check.jl b/JuliaBUGS/experiments/scripts/hdphmm_gradient_check.jl new file mode 100644 index 000000000..9fe55e83f --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hdphmm_gradient_check.jl @@ -0,0 +1,217 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using ForwardDiff + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs + +# Deterministic helper to add kappa mass to one index of a vector +function diagshift(alpha_beta::AbstractVector{T}, i::Integer, kappa::Real) where {T} + K = length(alpha_beta) + out = Vector{T}(undef, K) + @inbounds for j in 1:K + out[j] = alpha_beta[j] + end + if 1 <= i <= K + @inbounds out[i] = out[i] + kappa + end + return out +end + +# Allow these functions in @bugs +JuliaBUGS.@bugs_primitive Categorical Normal Beta Dirichlet exp diagshift +using LogDensityProblems + +parse_list(str) = begin + s = strip(str) + isempty(s) && return Int[] + if occursin(',', s) + return parse.(Int, split(s, ",")) + else + return [parse(Int, s)] + end +end + +seeds = let v = get(ENV, "AHDPG_SWEEP_SEEDS", get(ENV, "AHDPG_SEED", "1")) + xs = parse_list(v) + isempty(xs) ? [1] : xs +end +Ks = let v = get(ENV, "AHDPG_SWEEP_K", get(ENV, "AHDPG_K", "5")) + xs = parse_list(v) + isempty(xs) ? [5] : xs +end +Ts = let v = get(ENV, "AHDPG_SWEEP_T", get(ENV, "AHDPG_T", "200")) + xs = parse_list(v) + isempty(xs) ? 
[200] : xs +end + +alpha = try parse(Float64, get(ENV, "AHDPG_ALPHA", "5.0")) catch; 5.0 end +gamma = try parse(Float64, get(ENV, "AHDPG_GAMMA", "1.0")) catch; 1.0 end +kappa = try parse(Float64, get(ENV, "AHDPG_KAPPA", "0.0")) catch; 0.0 end + +eps = try parse(Float64, get(ENV, "AHDPG_EPS", "1e-5")) catch; 1e-5 end +verbose = get(ENV, "AHDPG_VERBOSE", "0") == "1" + +# Emission priors +m0 = try parse(Float64, get(ENV, "AHDPG_MU0", "0.0")) catch; 0.0 end +s0 = try parse(Float64, get(ENV, "AHDPG_MU0_STD", "1.0")) catch; 1.0 end +ℓ0 = try parse(Float64, get(ENV, "AHDPG_LOGSIGMA0", string(log(0.6)))) catch; log(0.6) end +τ0 = try parse(Float64, get(ENV, "AHDPG_LOGSIGMA0_STD", "0.3")) catch; 0.3 end + +function stick_break(v::AbstractVector) + K = length(v) + 1 + beta = similar(v, Float64, K) + stick = 1.0 + for k in 1:K-1 + beta[k] = v[k] * stick + stick *= (1 - v[k]) + end + beta[K] = stick + return beta +end + +function simulate_hdp_params(rng::AbstractRNG, K::Int; α::Real, γ::Real, κ::Real) + v = [rand(rng, Beta(1.0, γ)) for _ in 1:K-1] + beta = stick_break(v) + pi = Array{Float64}(undef, K, K) + for i in 1:K + a = α .* beta + a[i] += κ + pi[i, :] = rand(rng, Dirichlet(a)) + end + return beta, pi +end + +function simulate_y(rng::AbstractRNG, T::Int, K::Int; rho, pi, mu, sigma) + z = Vector{Int}(undef, T) + y = Vector{Float64}(undef, T) + z[1] = rand(rng, Categorical(rho)) + y[1] = rand(rng, Normal(mu[z[1]], sigma[z[1]])) + for t in 2:T + z[t] = rand(rng, Categorical(pi[z[t - 1], :])) + y[t] = rand(rng, Normal(mu[z[t]], sigma[z[t]])) + end + return y +end + +function hdphmm_param_model() + @bugs begin + # Emissions + for k in 1:K + mu[k] ~ Normal(mu0, mu0_std) + log_sigma[k] ~ Normal(logsigma0, logsigma0_std) + sigma[k] = exp(log_sigma[k]) + end + + # Stick-breaking weights + for k in 1:(K - 1) + v[k] ~ Beta(1.0, gamma) + end + # Deterministic stick-breaking to beta without reassigning scalars + remain[1] = 1.0 + for k in 1:(K - 1) + beta[k] = v[k] * remain[k] + remain[k + 1] = remain[k] * (1.0 - v[k]) + end + beta[K] = remain[K] + + # Transition rows with HDP weak-limit prior + sticky kappa via diagshift + for j in 1:K + alpha_beta[j] = alpha * beta[j] + end + for i in 1:K + transition[i, 1:K] ~ Dirichlet(diagshift(alpha_beta[1:K], i, kappa)) + end + + # Initial distribution = beta (standard HDP-HMM choice) + z[1] ~ Categorical(beta[1:K]) + y[1] ~ Normal(mu[z[1]], sigma[z[1]]) + for t in 2:T + z[t] ~ Categorical(transition[z[t - 1], 1:K]) + y[t] ~ Normal(mu[z[t]], sigma[z[t]]) + end + end +end + +function finite_difference(f, x; ϵ=1e-5) + g = similar(x) + fx = f(x) + for i in eachindex(x) + xi = x[i] + x[i] = xi + ϵ + fp = f(x) + x[i] = xi - ϵ + fm = f(x) + x[i] = xi + g[i] = (fp - fm) / (2ϵ) + end + return g, fx +end + +function run_case(seed::Int, K::Int, T::Int; α::Float64, γ::Float64, κ::Float64, ϵ::Float64, verbose::Bool) + rng = MersenneTwister(seed) + # Simulate from prior for data generation + mu_true = collect(range(-1.5, 1.5; length=K)) + sigma_true = fill(0.6, K) + beta_true, pi_true = simulate_hdp_params(rng, K; α=α, γ=γ, κ=κ) + y = simulate_y(rng, T, K; rho=beta_true, pi=pi_true, mu=mu_true, sigma=sigma_true) + + data = ( + T = T, + K = K, + y = y, + alpha = α, + gamma = γ, + kappa = κ, + mu0 = m0, + mu0_std = s0, + logsigma0 = ℓ0, + logsigma0_std = τ0, + ) + + model, θ0 = compile_autmarg(hdphmm_param_model(), data) + target(θ) = Base.invokelatest(LogDensityProblems.logdensity, model, θ) + autograd = ForwardDiff.gradient(target, θ0) + fdgrad, logp = finite_difference(target, 
copy(θ0); ϵ=ϵ) + + diffs = autograd .- fdgrad + max_abs_diff = maximum(abs, diffs) + denom = map(i -> max(max(abs(autograd[i]), abs(fdgrad[i])), 1e-12), eachindex(θ0)) + rel_diffs = abs.(diffs) ./ denom + max_rel_diff = maximum(rel_diffs) + + if verbose + @printf "# HDP-HMM gradient check\n" + @printf "seed=%d, K=%d, T=%d\n" seed K T + @printf "logp = %.12f\n" logp + for i in eachindex(θ0) + @printf "θ[%d]: autodiff=%.6e fd=%.6e diff=%.2e rel=%.2e\n" i autograd[i] fdgrad[i] diffs[i] rel_diffs[i] + end + end + + return logp, max_abs_diff, max_rel_diff +end + +is_single = (length(seeds) == 1) && (length(Ks) == 1) && (length(Ts) == 1) + +if !is_single + @printf "# HDP-HMM gradient sweep (sticky κ=%.2f)\n" kappa + @printf "# seed,K,T,max_abs_diff,max_rel_diff,logp\n" +end + +for seed in seeds, K in Ks, T in Ts + logp, max_abs_diff, max_rel_diff = run_case(seed, K, T; α=alpha, γ=gamma, κ=kappa, ϵ=eps, verbose=(verbose && is_single)) + if is_single && !verbose + @printf "# HDP-HMM gradient check (sticky κ=%.2f)\n" kappa + @printf "seed=%d, K=%d, T=%d\n" seed K T + @printf "logp = %.12f\n" logp + @printf "max_abs_diff=%.3e, max_rel_diff=%.3e\n" max_abs_diff max_rel_diff + elseif !is_single + @printf "%d,%d,%d,%.3e,%.3e,%.12f\n" seed K T max_abs_diff max_rel_diff logp + end +end diff --git a/JuliaBUGS/experiments/scripts/hmm_correctness_sweep.jl b/JuliaBUGS/experiments/scripts/hmm_correctness_sweep.jl new file mode 100644 index 000000000..0f7169137 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hmm_correctness_sweep.jl @@ -0,0 +1,117 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive Categorical Normal +using LogDensityProblems +using LogExpFunctions: logsumexp + +function simulate_hmm(rng::AbstractRNG, T::Int, K::Int; init_probs, transition, means, sigmas) + states = Vector{Int}(undef, T) + obs = Vector{Float64}(undef, T) + + states[1] = rand(rng, Categorical(init_probs)) + obs[1] = rand(rng, Normal(means[states[1]], sigmas[states[1]])) + + for t in 2:T + prev = states[t - 1] + states[t] = rand(rng, Categorical(transition[prev, :])) + obs[t] = rand(rng, Normal(means[states[t]], sigmas[states[t]])) + end + + return (; states, obs) +end + +function forward_logp(obs, init_probs, transition, means, sigmas) + T = length(obs) + K = length(init_probs) + + log_emissions = Array{Float64}(undef, T, K) + for t in 1:T, k in 1:K + log_emissions[t, k] = logpdf(Normal(means[k], sigmas[k]), obs[t]) + end + + log_transition = log.(transition) + log_alpha = Vector{Float64}(undef, K) + for k in 1:K + log_alpha[k] = log(init_probs[k]) + log_emissions[1, k] + end + + tmp = similar(log_alpha) + for t in 2:T + for k in 1:K + tmp[k] = log_emissions[t, k] + logsumexp(log_alpha .+ log_transition[:, k]) + end + log_alpha, tmp = tmp, log_alpha + end + + return logsumexp(log_alpha) +end + +function build_hmm_model() + @bugs begin + z[1] ~ Categorical(init_probs) + y[1] ~ Normal(means[z[1]], sigmas[z[1]]) + + for t in 2:T + z[t] ~ Categorical(transition[z[t - 1], 1:K]) + y[t] ~ Normal(means[z[t]], sigmas[z[t]]) + end + end +end + +function default_params(K) + init_probs = fill(1.0 / K, K) + diag = 0.85 + off = (1.0 - diag) / (K - 1) + transition = fill(off, K, K) + for k in 1:K + transition[k, k] = diag + end + means = collect(range(-1.5, 1.5; length=K)) + sigmas = fill(0.6, K) + return init_probs, transition, means, sigmas +end + +function run_case(; K::Int, T::Int, seed::Int) + 
rng = MersenneTwister(seed) + init_probs, transition, means, sigmas = default_params(K) + sim = simulate_hmm(rng, T, K; init_probs=init_probs, transition=transition, means=means, sigmas=sigmas) + + model_def = build_hmm_model() + data = ( + T = T, + K = K, + y = sim.obs, + init_probs = init_probs, + transition = transition, + means = means, + sigmas = sigmas, + ) + + model, θ0 = compile_autmarg(model_def, data) + logp = Base.invokelatest(LogDensityProblems.logdensity, model, θ0) + logp_ref = forward_logp(sim.obs, init_probs, transition, means, sigmas) + return logp, logp_ref +end + +seed_str = strip(get(ENV, "AM_SWEEP_SEEDS", "1")) +seed_vals = isempty(seed_str) ? [1] : parse.(Int, split(seed_str, ',')) +K_str = strip(get(ENV, "AM_SWEEP_K", "2,4")) +Ks = parse.(Int, split(K_str, ',')) +T_str = strip(get(ENV, "AM_SWEEP_T", "50,200")) +Ts = parse.(Int, split(T_str, ',')) + +@printf "# HMM correctness sweep\n" +@printf "# seed,K,T,logp_autmarg,logp_forward,diff\n" +for seed in seed_vals, K in Ks, T in Ts + logp, logp_ref = run_case(K=K, T=T, seed=seed) + diff = logp - logp_ref + @printf "%d,%d,%d,%.12f,%.12f,%.3e\n" seed K T logp logp_ref diff +end diff --git a/JuliaBUGS/experiments/scripts/hmm_gradient_check.jl b/JuliaBUGS/experiments/scripts/hmm_gradient_check.jl new file mode 100644 index 000000000..36c73b6c0 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hmm_gradient_check.jl @@ -0,0 +1,166 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using ForwardDiff + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs, getparams +JuliaBUGS.@bugs_primitive Categorical Normal exp +using LogDensityProblems + +parse_list(str) = begin + s = strip(str) + isempty(s) && return Int[] + if occursin(',', s) + return parse.(Int, split(s, ",")) + else + return [parse(Int, s)] + end +end + +seeds = let v = get(ENV, "AGC_SWEEP_SEEDS", get(ENV, "AGC_SEED", "1")) + xs = parse_list(v) + isempty(xs) ? [1] : xs +end +Ks = let v = get(ENV, "AGC_SWEEP_K", get(ENV, "AGC_K", "2")) + xs = parse_list(v) + isempty(xs) ? [2] : xs +end +Ts = let v = get(ENV, "AGC_SWEEP_T", get(ENV, "AGC_T", "50")) + xs = parse_list(v) + isempty(xs) ? 
[50] : xs +end + +eps = try parse(Float64, get(ENV, "AGC_EPS", "1e-5")) catch; 1e-5 end +verbose = get(ENV, "AGC_VERBOSE", "0") == "1" +function default_hmm_params(K) + means_true = collect(range(-1.0, 1.0; length=K)) + sigmas_true = fill(0.6, K) + return means_true, sigmas_true +end + +function simulate_hmm(rng, T, K; init_probs, transition, means, sigmas) + z = Vector{Int}(undef, T) + y = Vector{Float64}(undef, T) + z[1] = rand(rng, Categorical(init_probs)) + y[1] = rand(rng, Normal(means[z[1]], sigmas[z[1]])) + for t in 2:T + z[t] = rand(rng, Categorical(transition[z[t - 1], :])) + y[t] = rand(rng, Normal(means[z[t]], sigmas[z[t]])) + end + return (; z, y) +end + +function priors_from_truth(means_true, sigmas_true) + mu_prior_mean = means_true + mu_prior_std = 1.0 + logsigma_prior_mean = log.(sigmas_true) + logsigma_prior_std = 0.3 + return mu_prior_mean, mu_prior_std, logsigma_prior_mean, logsigma_prior_std +end + +function hmm_param_model() + @bugs begin + for k in 1:K + mu[k] ~ Normal(mu_prior_mean[k], mu_prior_std) + log_sigma[k] ~ Normal(logsigma_prior_mean[k], logsigma_prior_std) + sigma[k] = exp(log_sigma[k]) + end + + z[1] ~ Categorical(init_probs) + y[1] ~ Normal(mu[z[1]], sigma[z[1]]) + + for t in 2:T + z[t] ~ Categorical(transition[z[t - 1], 1:K]) + y[t] ~ Normal(mu[z[t]], sigma[z[t]]) + end + end +end + +function finite_difference(f, x; ϵ=1e-5) + g = similar(x) + fx = f(x) + for i in eachindex(x) + xi = x[i] + x[i] = xi + ϵ + fp = f(x) + x[i] = xi - ϵ + fm = f(x) + x[i] = xi + g[i] = (fp - fm) / (2ϵ) + end + return g, fx +end + +function run_case(seed::Int, K::Int, T::Int; ϵ::Float64=1e-5, verbose::Bool=false) + rng = MersenneTwister(seed) + init_probs = fill(1.0 / K, K) + diag = 0.85 + off = (1 - diag) / (K - 1) + transition = fill(off, K, K) + for k in 1:K + transition[k, k] = diag + end + means_true, sigmas_true = default_hmm_params(K) + sim = simulate_hmm(rng, T, K; init_probs=init_probs, transition=transition, means=means_true, sigmas=sigmas_true) + + mu_prior_mean, mu_prior_std, logsigma_prior_mean, logsigma_prior_std = priors_from_truth(means_true, sigmas_true) + + data = ( + T = T, + K = K, + y = sim.y, + init_probs = init_probs, + transition = transition, + mu_prior_mean = mu_prior_mean, + mu_prior_std = mu_prior_std, + logsigma_prior_mean = logsigma_prior_mean, + logsigma_prior_std = logsigma_prior_std, + ) + + model, θ0 = compile_autmarg(hmm_param_model(), data) + target(θ) = Base.invokelatest(LogDensityProblems.logdensity, model, θ) + autograd = ForwardDiff.gradient(target, θ0) + fdgrad, logp = finite_difference(target, copy(θ0); ϵ=ϵ) + + diffs = autograd .- fdgrad + max_abs_diff = maximum(abs, diffs) + denom = map(i -> max(max(abs(autograd[i]), abs(fdgrad[i])), 1e-12), eachindex(θ0)) + rel_diffs = abs.(diffs) ./ denom + max_rel_diff = maximum(rel_diffs) + + if verbose + @printf "# HMM gradient check\n" + @printf "seed=%d, K=%d, T=%d\n" seed K T + @printf "logp = %.12f\n" logp + for i in eachindex(θ0) + @printf "θ[%d]: autodiff=%.6e fd=%.6e diff=%.2e rel=%.2e\n" i autograd[i] fdgrad[i] diffs[i] rel_diffs[i] + end + end + + return logp, max_abs_diff, max_rel_diff +end + +# Decide whether this is a sweep or single run +is_single = (length(seeds) == 1) && (length(Ks) == 1) && (length(Ts) == 1) + +if !is_single + @printf "# HMM gradient sweep\n" + @printf "# seed,K,T,max_abs_diff,max_rel_diff,logp\n" +end + +for seed in seeds, K in Ks, T in Ts + logp, max_abs_diff, max_rel_diff = run_case(seed, K, T; ϵ=eps, verbose=(verbose && is_single)) + if is_single && !verbose + 
@printf "# HMM gradient check\n" + @printf "seed=%d, K=%d, T=%d\n" seed K T + @printf "logp = %.12f\n" logp + @printf "max_abs_diff=%.3e, max_rel_diff=%.3e\n" max_abs_diff max_rel_diff + elseif !is_single + @printf "%d,%d,%d,%.3e,%.3e,%.12f\n" seed K T max_abs_diff max_rel_diff logp + end +end diff --git a/JuliaBUGS/experiments/scripts/hmm_marginal_logp.jl b/JuliaBUGS/experiments/scripts/hmm_marginal_logp.jl new file mode 100644 index 000000000..19d8d5426 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hmm_marginal_logp.jl @@ -0,0 +1,97 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive Categorical Normal +using LogDensityProblems +using LogExpFunctions: logsumexp + +function simulate_hmm(rng::AbstractRNG, T::Int; init_probs, transition, means, sigmas) + states = Vector{Int}(undef, T) + obs = Vector{Float64}(undef, T) + + states[1] = rand(rng, Categorical(init_probs)) + obs[1] = rand(rng, Normal(means[states[1]], sigmas[states[1]])) + + for t in 2:T + prev = states[t - 1] + states[t] = rand(rng, Categorical(transition[prev, :])) + obs[t] = rand(rng, Normal(means[states[t]], sigmas[states[t]])) + end + + return (; states, obs) +end + +function forward_logp(obs::AbstractVector, init_probs, transition, means, sigmas) + T = length(obs) + K = length(init_probs) + + log_emissions = Array{Float64}(undef, T, K) + for t in 1:T, k in 1:K + log_emissions[t, k] = logpdf(Normal(means[k], sigmas[k]), obs[t]) + end + + log_transition = log.(transition) + log_alpha = Vector{Float64}(undef, K) + for k in 1:K + log_alpha[k] = log(init_probs[k]) + log_emissions[1, k] + end + + tmp = similar(log_alpha) + for t in 2:T + for k in 1:K + tmp[k] = log_emissions[t, k] + logsumexp(log_alpha .+ log_transition[:, k]) + end + log_alpha, tmp = tmp, log_alpha + end + + return logsumexp(log_alpha) +end + +function build_hmm_model() + @bugs begin + z[1] ~ Categorical(init_probs) + y[1] ~ Normal(means[z[1]], sigmas[z[1]]) + + for t in 2:T + z[t] ~ Categorical(transition[z[t - 1], 1:K]) + y[t] ~ Normal(means[z[t]], sigmas[z[t]]) + end + end +end + +rng = MersenneTwister(get(ENV, "AM_SEED", "1") |> x -> parse(Int, x)) +T = get(ENV, "AM_T", "50") |> x -> parse(Int, x) +K = 2 +init_probs = [0.6, 0.4] +transition = [0.95 0.05; 0.10 0.90] +means = [-1.0, 1.5] +sigmas = fill(0.4, K) + +sim = simulate_hmm(rng, T; init_probs=init_probs, transition=transition, means=means, sigmas=sigmas) + +model_def = build_hmm_model() +data = ( + T = T, + K = K, + y = sim.obs, + init_probs = init_probs, + transition = transition, + means = means, + sigmas = sigmas, +) + +model, θ0 = compile_autmarg(model_def, data) +logp = Base.invokelatest(LogDensityProblems.logdensity, model, θ0) + +logp_ref = forward_logp(sim.obs, init_probs, transition, means, sigmas) +diff = logp - logp_ref + +@printf "Marginalized log probability (T=%d): %.6f\n" T logp +@printf "Forward reference log probability: %.6f (Δ = %.2e)\n" logp_ref diff diff --git a/JuliaBUGS/experiments/scripts/hmm_scaling_bench.jl b/JuliaBUGS/experiments/scripts/hmm_scaling_bench.jl new file mode 100644 index 000000000..855ab62c2 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hmm_scaling_bench.jl @@ -0,0 +1,131 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using Statistics +using BenchmarkTools + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive 
Categorical Normal
+using LogDensityProblems
+
+const JModel = JuliaBUGS.Model
+
+parse_list(str) = begin
+    s = strip(str)
+    isempty(s) && return Int[]
+    if occursin(',', s)
+        return parse.(Int, split(s, ","))
+    else
+        return [parse(Int, s)]
+    end
+end
+
+seeds = let v = get(ENV, "AS_SWEEP_SEEDS", get(ENV, "AS_SEED", "1"))
+    xs = parse_list(v)
+    isempty(xs) ? [1] : xs
+end
+Ks = let v = get(ENV, "AS_SWEEP_K", get(ENV, "AS_K", "2,4"))
+    xs = parse_list(v)
+    isempty(xs) ? [2, 4] : xs
+end
+Ts = let v = get(ENV, "AS_SWEEP_T", get(ENV, "AS_T", "50,200"))
+    xs = parse_list(v)
+    isempty(xs) ? [50, 200] : xs
+end
+trials = try parse(Int, get(ENV, "AS_TRIALS", "5")) catch; 5 end
+
+function simulate_hmm(rng::AbstractRNG, T::Int, K::Int; init_probs, transition, means, sigmas)
+    states = Vector{Int}(undef, T)
+    obs = Vector{Float64}(undef, T)
+
+    states[1] = rand(rng, Categorical(init_probs))
+    obs[1] = rand(rng, Normal(means[states[1]], sigmas[states[1]]))
+
+    for t in 2:T
+        prev = states[t - 1]
+        states[t] = rand(rng, Categorical(transition[prev, :]))
+        obs[t] = rand(rng, Normal(means[states[t]], sigmas[states[t]]))
+    end
+
+    return (; states, obs)
+end
+
+function build_hmm_model()
+    @bugs begin
+        z[1] ~ Categorical(init_probs)
+        y[1] ~ Normal(means[z[1]], sigmas[z[1]])
+
+        for t in 2:T
+            z[t] ~ Categorical(transition[z[t - 1], 1:K])
+            y[t] ~ Normal(means[z[t]], sigmas[z[t]])
+        end
+    end
+end
+
+function default_params(K)
+    init_probs = fill(1.0 / K, K)
+    diag = 0.85
+    off = (1.0 - diag) / (K - 1)
+    transition = fill(off, K, K)
+    for k in 1:K
+        transition[k, k] = diag
+    end
+    means = collect(range(-1.5, 1.5; length=K))
+    sigmas = fill(0.6, K)
+    return init_probs, transition, means, sigmas
+end
+
+function frontier_stats(model)
+    gd = model.graph_evaluation_data
+    order = gd.marginalization_order
+    keys = gd.minimal_cache_keys
+    widths = [length(get(keys, idx, Int[])) for idx in order]
+    if isempty(widths)
+        return 0, 0.0, 0
+    end
+    return maximum(widths), mean(widths), sum(widths)
+end
+
+function bench_case(; K::Int, T::Int, seed::Int, trials::Int)
+    rng = MersenneTwister(seed)
+    init_probs, transition, means, sigmas = default_params(K)
+    sim = simulate_hmm(rng, T, K; init_probs=init_probs, transition=transition, means=means, sigmas=sigmas)
+
+    model_def = build_hmm_model()
+    data = (
+        T = T,
+        K = K,
+        y = sim.obs,
+        init_probs = init_probs,
+        transition = transition,
+        means = means,
+        sigmas = sigmas,
+    )
+
+    model, θ0 = compile_autmarg(model_def, data)
+    # Always use the interleaved (time-first) order for scaling
+    model = make_model_with_order(model, build_interleaved_order(model))
+    # Warm-up JIT
+    _ = Base.invokelatest(LogDensityProblems.logdensity, model, θ0)
+    _ = Base.invokelatest(LogDensityProblems.logdensity, model, θ0)
+
+    # Benchmark with BenchmarkTools; `@belapsed` takes the minimum over `trials` samples
+    min_time = @belapsed Base.invokelatest(LogDensityProblems.logdensity, $model, $θ0) samples=trials evals=1
+
+    max_frontier, mean_frontier, sum_frontier = frontier_stats(model)
+    # Return the minimum time, a freshly computed logp, and frontier stats
+    logp = Base.invokelatest(LogDensityProblems.logdensity, model, θ0)
+    return min_time, logp, max_frontier, mean_frontier, sum_frontier
+end
+
+@printf "# HMM scaling benchmark (auto-marginalization)\n"
+@printf "# seed,K,T,trials,min_time_sec,logp,max_frontier,mean_frontier,sum_frontier\n"
+for seed in seeds, K in Ks, T in Ts
+    min_time, logp, max_f, mean_f, sum_f = bench_case(K=K, T=T, seed=seed, trials=trials)
+    @printf 
"%d,%d,%d,%d,%.6e,%.12f,%d,%.3f,%d\n" seed K T trials mean_time logp max_f mean_f sum_f +end diff --git a/JuliaBUGS/experiments/scripts/hmt_order_comparison.jl b/JuliaBUGS/experiments/scripts/hmt_order_comparison.jl new file mode 100644 index 000000000..e23d416a3 --- /dev/null +++ b/JuliaBUGS/experiments/scripts/hmt_order_comparison.jl @@ -0,0 +1,141 @@ +#!/usr/bin/env julia + +using Random +using Distributions +using Printf +using Statistics +using BenchmarkTools + +include(joinpath(@__DIR__, "..", "utils.jl")) + +using JuliaBUGS +using JuliaBUGS: @bugs +JuliaBUGS.@bugs_primitive Categorical Normal +using LogDensityProblems + +# ========================== +# Env configuration +# ========================== + +B = try parse(Int, get(ENV, "AHMT_B", "2")) catch; 2 end +depth = try parse(Int, get(ENV, "AHMT_DEPTH", "8")) catch; 8 end +K = try parse(Int, get(ENV, "AHMT_K", "4")) catch; 4 end +seed = try parse(Int, get(ENV, "AHMT_SEED", "1")) catch; 1 end +trials = try parse(Int, get(ENV, "AHMT_TRIALS", "10")) catch; 10 end +mode = lowercase(get(ENV, "AHMT_MODE", "frontier")) # frontier | timed | dfs +cost_thresh = try parse(Float64, get(ENV, "AHMT_COST_THRESH", "1.0e8")) catch; 1.0e8 end + +function num_nodes(B::Int, depth::Int) + if depth <= 0 + return 0 + end + if B == 1 + return depth + else + return (B^depth - 1) ÷ (B - 1) + end +end + +function parent_index(i::Int, B::Int) + i == 1 && return 0 + return fld(i - 2, B) + 1 +end + +function default_params(K) + init_probs = fill(1.0 / K, K) + diag = 0.85 + off = (1.0 - diag) / (K - 1) + transition = fill(off, K, K) + for k in 1:K + transition[k, k] = diag + end + means = collect(range(-1.5, 1.5; length=K)) + sigmas = fill(0.6, K) + return init_probs, transition, means, sigmas +end + +function simulate_hmt(rng::AbstractRNG, B::Int, depth::Int, K::Int; init_probs, transition, means, sigmas) + N = num_nodes(B, depth) + z = Vector{Int}(undef, N) + y = Vector{Float64}(undef, N) + + z[1] = rand(rng, Categorical(init_probs)) + y[1] = rand(rng, Normal(means[z[1]], sigmas[z[1]])) + for i in 2:N + p = parent_index(i, B) + z[i] = rand(rng, Categorical(transition[z[p], :])) + y[i] = rand(rng, Normal(means[z[i]], sigmas[z[i]])) + end + return (; z, y) +end + +function hmt_model() + @bugs begin + z[1] ~ Categorical(init_probs) + y[1] ~ Normal(means[z[1]], sigmas[z[1]]) + for i in 2:N + z[i] ~ Categorical(transition[z[parent[i]], 1:K]) + y[i] ~ Normal(means[z[i]], sigmas[z[i]]) + end + end +end + +function compile_case(B, depth, K, seed) + rng = MersenneTwister(seed) + N = num_nodes(B, depth) + init_probs, transition, means, sigmas = default_params(K) + sim = simulate_hmt(rng, B, depth, K; init_probs=init_probs, transition=transition, means=means, sigmas=sigmas) + + parent = [parent_index(i, B) for i in 1:N] + data = ( + B = B, + depth = depth, + N = N, + K = K, + y = sim.y, + init_probs = init_probs, + transition = transition, + means = means, + sigmas = sigmas, + parent = parent, + ) + return compile_autmarg(hmt_model(), data) +end + +function frontier_cost_proxy_for(model; K_hint::Real) + max_f, mean_f, sum_f, logproxy = frontier_cost_proxy(model; K_hint=K_hint) + return max_f, mean_f, sum_f, logproxy +end + +model, θ0 = compile_case(B, depth, K, seed) + +orders = Dict{String,Function}( + "dfs" => () -> make_model_with_order(model, build_hmt_dfs_order(model; B_hint=B)), + "bfs" => () -> make_model_with_order(model, build_hmt_bfs_order(model)), + "random_dfs" => () -> make_model_with_order(model, build_hmt_dfs_order(model; B_hint=B, 
rng=MersenneTwister(seed+1), randomized=true)),
+    "min_fill" => () -> make_model_with_order(model, build_min_fill_order(model; rng=MersenneTwister(seed+2), num_restarts=3)),
+    "min_degree" => () -> make_model_with_order(model, build_min_degree_order(model; rng=MersenneTwister(seed+3), num_restarts=3)),
+)
+
+@printf "# HMT order comparison (B=%d, depth=%d, K=%d)\n" B depth K
+@printf "# order,B,K,depth,N,max_frontier,mean_frontier,sum_frontier,log_cost_proxy,min_time_sec,logp\n"
+for (name, buildfun) in orders
+    m2 = buildfun()
+    max_f, mean_f, sum_f, logproxy = frontier_cost_proxy_for(m2; K_hint=K)
+    do_time = (mode == "timed") || (name == "dfs" && mode != "frontier")
+    tmin = NaN
+    logp = NaN
+    if do_time && logproxy <= log(cost_thresh)
+        _ = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0)
+        _ = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0)
+        tmin = @belapsed Base.invokelatest(LogDensityProblems.logdensity, $m2, $θ0) samples=trials evals=1
+        logp = Base.invokelatest(LogDensityProblems.logdensity, m2, θ0)
+    end
+    # Number of nodes in tree
+    N = num_nodes(B, depth)
+    @printf "%s,%d,%d,%d,%d,%d,%.3f,%d,%.3e,%s,%s\n" name B K depth N max_f mean_f sum_f logproxy (
+        isnan(tmin) ? "NA" : @sprintf("%.6e", tmin)
+    ) (
+        isnan(logp) ? "NA" : @sprintf("%.12f", logp)
+    )
+end
diff --git a/JuliaBUGS/experiments/utils.jl b/JuliaBUGS/experiments/utils.jl
new file mode 100644
index 000000000..69e075fbd
--- /dev/null
+++ b/JuliaBUGS/experiments/utils.jl
@@ -0,0 +1,768 @@
+using LogDensityProblems
+using JuliaBUGS
+using JuliaBUGS: compile
+using AbstractPPL
+using MetaGraphsNext
+using Random
+using LogExpFunctions
+
+const JModel = JuliaBUGS.Model
+
+"""
+    compile_autmarg(model_def, data; transformed=true)
+
+Compile a BUGS model, switch it to the transformed (unconstrained) parameterization,
+and enable the `UseAutoMarginalization` evaluation mode.
+Returns the model together with a zero parameter vector of the matching dimension.
+"""
+function compile_autmarg(model_def, data; transformed=true)
+    m = compile(model_def, data)
+    m = JModel.settrans(m, transformed)
+    m = JModel.set_evaluation_mode(m, JModel.UseAutoMarginalization())
+    if !(m.evaluation_mode isa JModel.UseAutoMarginalization)
+        error(
+            "Auto-marginalization mode was not activated (got $(typeof(m.evaluation_mode))). " *
+            "Ensure the model can precompute marginalization caches before running experiments.",
+        )
+    end
+    D = LogDensityProblems.dimension(m)
+    return m, zeros(D)
+end
+
+"""
+    build_interleaved_order(model)
+
+For models with paired latent/observation variables named `z[i]` and `y[i]`
+(HMMs, mixture models), return an order that keeps all non-(z, y) nodes first
+and then interleaves `z[i], y[i]` by ascending `i`. Falls back to the existing
+`sorted_nodes` order when no such names are present.
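+
+# Example
+
+A minimal sketch (`build_hmm_model` and `data` as in the experiment scripts):
+
+```julia
+model, θ0 = compile_autmarg(build_hmm_model(), data)
+order = build_interleaved_order(model)       # …, z[1], y[1], z[2], y[2], …
+model = make_model_with_order(model, order)  # attach order and minimal cache keys
+```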
+""" +function build_interleaved_order(model) + gd = model.graph_evaluation_data + z_idxs = Dict{Int,Int}() + y_idxs = Dict{Int,Int}() + other = Int[] + for (j, vn) in enumerate(gd.sorted_nodes) + s = string(vn) + if startswith(s, "z[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0; z_idxs[i] = j; else; push!(other, j); end + elseif startswith(s, "y[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0; y_idxs[i] = j; else; push!(other, j); end + else + push!(other, j) + end + end + order = copy(other) + if !isempty(z_idxs) + for i in sort(collect(keys(z_idxs))) + zi = z_idxs[i] + push!(order, zi) + yi = get(y_idxs, i, 0) + if yi != 0; push!(order, yi); end + end + end + return order +end + +""" + prepare_minimal_cache_keys(model, order) + +Wrapper around `JuliaBUGS.Model._precompute_minimal_cache_keys` returning a Dict. +""" +function prepare_minimal_cache_keys(model, order::AbstractVector{<:Integer}) + return JModel._precompute_minimal_cache_keys(model, collect(order)) +end + +""" + build_fhmm_interleaved_order(model) + +For Factorial HMMs with names like `z[c,t]` and `y[t]`, return an order that +keeps non (z,y) nodes first and then, for t=1:T, lists `z[1,t], z[2,t], …, z[C,t], y[t]`. +Falls back to `sorted_nodes` when names are not present. +""" +function build_fhmm_interleaved_order(model) + gd = model.graph_evaluation_data + # Collect indices for discrete variables only (z[c,t]) + z_idxs = Dict{Tuple{Int,Int},Int}() # (c,t) -> node idx + max_c, max_t = 0, 0 + for (j, vn) in enumerate(gd.sorted_nodes) + s = string(vn) + if startswith(s, "z[") + # Extract indices inside brackets, e.g., "c,t" + inner = replace(s[3:end-1], " " => "") + parts = split(inner, ',') + if length(parts) == 2 + c = try parse(Int, parts[1]) catch; -1 end + t = try parse(Int, parts[2]) catch; -1 end + if c > 0 && t > 0 + z_idxs[(c, t)] = j + max_c = max(max_c, c) + max_t = max(max_t, t) + end + end + end + end + # Build discrete-first interleaved-by-time order: for each time, z[1,t],...,z[C,t] + disc_order = Int[] + for t in 1:max_t + for c in 1:max_c + idx = get(z_idxs, (c, t), 0) + if idx != 0 + push!(disc_order, idx) + end + end + end + # Lift to full evaluation order (places emissions y[t] when their discrete parents are ready) + return build_eval_order_from_discrete_order(model, disc_order) +end + +""" + build_fhmm_states_first_order(model) + +For FHMMs, return an order that lists all `z[c,t]` in increasing t and c first, +then all `y[t]`. This demonstrates poor ordering (frontier explosion) while +preserving topological constraints across time. 
+""" +function build_fhmm_states_first_order(model) + gd = model.graph_evaluation_data + z_idxs = Dict{Tuple{Int,Int},Int}() + max_c, max_t = 0, 0 + for (j, vn) in enumerate(gd.sorted_nodes) + s = string(vn) + if startswith(s, "z[") + inner = replace(s[3:end-1], " " => "") + parts = split(inner, ',') + if length(parts) == 2 + c = try parse(Int, parts[1]) catch; -1 end + t = try parse(Int, parts[2]) catch; -1 end + if c > 0 && t > 0 + z_idxs[(c, t)] = j + max_c = max(max_c, c) + max_t = max(max_t, t) + end + end + end + end + # Discrete states-first order (by time then chain) + disc_order = Int[] + for t in 1:max_t + for c in 1:max_c + idx = get(z_idxs, (c, t), 0) + if idx != 0 + push!(disc_order, idx) + end + end + end + return build_eval_order_from_discrete_order(model, disc_order) +end + +# ========================== +# Heuristic Order Construction +# ========================== + +## Helper previously used by heuristics removed to simplify experiments module +""" + build_eval_order_from_discrete_order(model, disc_order) + +Lift a discrete-variable elimination order to a full evaluation order by: +- Placing dependencies (direct parents) before each discrete var +- Placing observed nodes as soon as all their discrete parents are placed +- Topologically repairing remaining nodes +""" +function build_eval_order_from_discrete_order(model, disc_order::AbstractVector{<:Integer}) + gd = model.graph_evaluation_data + order_nodes = gd.sorted_nodes + n = length(order_nodes) + pos = Dict(order_nodes[i] => i for i in 1:n) + placed = fill(false, n) + out = Int[] + + parents(vn) = collect(MetaGraphsNext.inneighbor_labels(model.g, vn)) + + function place_with_dependencies(vn) + i = pos[vn] + if placed[i]; return; end + for p in parents(vn) + place_with_dependencies(p) + end + push!(out, i) + placed[i] = true + end + + # Precompute discrete parents of observed stochastic nodes + st_parents = JModel._get_stochastic_parents_indices(model) + obs_nodes = [i for i in 1:n if gd.is_stochastic_vals[i] && gd.is_observed_vals[i]] + obs_disc_parents = Dict{Int,Vector{Int}}() + for j in obs_nodes + ps = [p for p in st_parents[j] if gd.is_discrete_finite_vals[p] && !gd.is_observed_vals[p]] + obs_disc_parents[j] = ps + end + + # Place discrete vars following disc_order, then any observed nodes whose discrete parents are all placed + for idx in disc_order + place_with_dependencies(order_nodes[idx]) + # Place ready observed nodes + for j in obs_nodes + if !placed[j] + ps = obs_disc_parents[j] + if all(placed[p] for p in ps) + place_with_dependencies(order_nodes[j]) + end + end + end + end + + # Finally, place any remaining nodes + for vn in order_nodes + if !placed[pos[vn]] + place_with_dependencies(vn) + end + end + return out +end + +## Heuristic helpers removed from experiments to keep focus on consistent orders + +""" + frontier_cost_proxy(model; K_hint=2) + +Compute frontier statistics and a domain-aware proxy cost Σ_t K_hint^{width_t}. +Returns (max_width, mean_width, sum_width, proxy). 
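+# NOTE: the returned `proxy` is in log space, computed stably as the logsumexp of
+# `width_t * log(K_hint)`, i.e. `log(sum_t K_hint^width_t)`, so callers should compare
+# it against `log(threshold)`, as the `logproxy <= log(cost_thresh)` gate in
+# `hmt_order_comparison.jl` does.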
+""" +function frontier_cost_proxy(model; K_hint::Real=2) + gd = model.graph_evaluation_data + order = gd.marginalization_order + keys = gd.minimal_cache_keys + widths = [length(get(keys, idx, Int[])) for idx in order] + if isempty(widths) + return 0, 0.0, 0, -Inf + end + logK = log(float(K_hint)) + # Stable log-sum-exp of w*logK + log_terms = (w * logK for w in widths) + proxy_log = LogExpFunctions.logsumexp(log_terms) + return maximum(widths), mean(widths), sum(widths), proxy_log +end + +""" + topo_repair_order(model, desired_order) + +Given a desired ordering of node indices (w.r.t. `model.graph_evaluation_data.sorted_nodes`), +return a topologically valid order that preserves the desired sequence as much as possible +while ensuring all direct parents are placed before each node. +""" +function topo_repair_order(model, desired_order::AbstractVector{<:Integer}) + gd = model.graph_evaluation_data + n = length(gd.sorted_nodes) + desired = collect(desired_order) + # Extend with any nodes not explicitly listed + if length(desired) < n + present = Set(desired) + append!(desired, (i for i in 1:n if i ∉ present)) + end + + pos = Dict(gd.sorted_nodes[i] => i for i in 1:n) + placed = fill(false, n) + out = Int[] + + parents(vn) = collect(MetaGraphsNext.inneighbor_labels(model.g, vn)) + + function place_with_dependencies(vn) + i = pos[vn] + if placed[i] + return + end + for p in parents(vn) + place_with_dependencies(p) + end + push!(out, i) + placed[i] = true + end + + for idx in desired + vn = gd.sorted_nodes[idx] + place_with_dependencies(vn) + end + # Safety: ensure all nodes placed + for vn in gd.sorted_nodes + place_with_dependencies(vn) + end + return out +end + +""" + make_model_with_order(model, order) + +Return a new BUGSModel whose `graph_evaluation_data` carries the provided +topologically‑repaired marginalization order and the corresponding +`minimal_cache_keys` computed for that order. +""" +function make_model_with_order(model, order::AbstractVector{<:Integer}) + gd = model.graph_evaluation_data + n = length(gd.sorted_nodes) + # Repair order to satisfy direct parent dependencies + repaired = topo_repair_order(model, order) + # Compute minimal cache keys for this evaluation order + min_keys = JModel._precompute_minimal_cache_keys(model, repaired) + # Build a new GraphEvaluationData reusing cached fields but with new order/keys + gd2 = JModel.GraphEvaluationData{typeof(gd.node_function_vals),typeof(gd.loop_vars_vals)}( + gd.sorted_nodes, + gd.sorted_parameters, + gd.is_stochastic_vals, + gd.is_observed_vals, + gd.node_function_vals, + gd.loop_vars_vals, + gd.node_types, + gd.is_discrete_finite_vals, + min_keys, + repaired, + ) + # Return a shallow‑copy model with updated GraphEvaluationData + return JModel.BUGSModel(model; graph_evaluation_data=gd2) +end + +""" + build_states_first_order(model) + +Construct an order that places all non (z,y) nodes first, then all z[i] by i, +then all y[i] by i. Intended for HMM‑style models and for demonstrating poor +ordering effects. 
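+# Example (sketch, contrasting with the interleaved order on the same HMM-style model;
+# `model` and `K` as in the scripts):
+#
+#   bad  = make_model_with_order(model, build_states_first_order(model))
+#   good = make_model_with_order(model, build_interleaved_order(model))
+#   frontier_cost_proxy(bad; K_hint=K)   # max frontier width grows with T, cost ~ K^T
+#   frontier_cost_proxy(good; K_hint=K)  # max frontier width 1, cost ~ T*K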
+""" +function build_states_first_order(model) + gd = model.graph_evaluation_data + z_idxs = Dict{Int,Int}() + y_idxs = Dict{Int,Int}() + other = Int[] + for (j, vn) in enumerate(gd.sorted_nodes) + s = string(vn) + if startswith(s, "z[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0 + z_idxs[i] = j + else + push!(other, j) + end + elseif startswith(s, "y[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0 + y_idxs[i] = j + else + push!(other, j) + end + else + push!(other, j) + end + end + order = copy(other) + for i in sort!(collect(keys(z_idxs))) + push!(order, z_idxs[i]) + end + for i in sort!(collect(keys(y_idxs))) + push!(order, y_idxs[i]) + end + return order +end + +""" + build_fhmm_states_then_emissions_order(model) + +Construct a consistent but poor order for FHMMs: place all non-(z,y) nodes first, +then all discrete states z[c,t] (by increasing t, then c), followed by all +emissions y[t]. This tends to maximize frontier width across time and +demonstrates intractability when ordering is poor. +""" +function build_fhmm_states_then_emissions_order(model) + gd = model.graph_evaluation_data + z_idxs = Dict{Tuple{Int,Int},Int}() + y_idxs = Dict{Int,Int}() + other = Int[] + max_c, max_t = 0, 0 + for (j, vn) in enumerate(gd.sorted_nodes) + s = string(vn) + if startswith(s, "z[") + inner = replace(s[3:end-1], " " => "") + parts = split(inner, ',') + if length(parts) == 2 + c = try parse(Int, parts[1]) catch; -1 end + t = try parse(Int, parts[2]) catch; -1 end + if c > 0 && t > 0 + z_idxs[(c, t)] = j + max_c = max(max_c, c) + max_t = max(max_t, t) + end + end + elseif startswith(s, "y[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0 + y_idxs[i] = j + else + push!(other, j) + end + else + push!(other, j) + end + end + order = copy(other) + # All z's first (by time then chain) + for t in 1:max_t + for c in 1:max_c + idx = get(z_idxs, (c, t), 0) + if idx != 0 + push!(order, idx) + end + end + end + # Then all y's (by time) + for t in sort(collect(keys(y_idxs))) + push!(order, y_idxs[t]) + end + return order +end + +""" + build_hmt_dfs_order(model; B_hint=2, rng=nothing, randomized=false) + +Construct a discrete-first DFS order for HMTs (variables named `z[i]`, `y[i]`). +Discovers the tree structure directly from the model graph by following edges +between `z` nodes. If `randomized=true`, shuffles child order using `rng`. +Returns a full evaluation order via `build_eval_order_from_discrete_order`. 
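+
+# Example
+
+A sketch, assuming an HMT model compiled as in `hmt_order_comparison.jl`:
+
+```julia
+model, θ0 = compile_case(B, depth, K, seed)
+m_dfs = make_model_with_order(model, build_hmt_dfs_order(model; B_hint=B))
+m_rnd = make_model_with_order(model,
+    build_hmt_dfs_order(model; B_hint=B, rng=MersenneTwister(1), randomized=true))
+```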
+""" +function build_hmt_dfs_order(model; B_hint::Integer=2, rng=nothing, randomized::Bool=false) + gd = model.graph_evaluation_data + nodes = gd.sorted_nodes + n = length(nodes) + # Map VarName->sorted index and VarName->z index i + pos = Dict(nodes[i] => i for i in 1:n) + z_index = Dict{typeof(nodes[1]),Int}() + for (j, vn) in enumerate(nodes) + s = string(vn) + if startswith(s, "z[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0 + z_index[vn] = i + end + end + end + # Build adjacency among z nodes via out-neighbors + children = Dict{Int,Vector{Int}}() + has_parent = Dict{Int,Bool}() + for (vn, i) in z_index + kids = Int[] + for w in MetaGraphsNext.outneighbor_labels(model.g, vn) + if haskey(z_index, w) + j = z_index[w] + push!(kids, j) + has_parent[j] = true + end + end + children[i] = kids + has_parent[i] = get(has_parent, i, false) + end + # Roots: z nodes without z-parents + roots = sort([i for (i, hp) in has_parent if !hp]) + order_z = Int[] # sorted indices (positions) for z nodes + # Helper: get VarName for a given z index i + z_vn_by_i = Dict{Int,typeof(nodes[1])}() + for (vn, i) in z_index + z_vn_by_i[i] = vn + end + function dfs(i::Int) + push!(order_z, pos[z_vn_by_i[i]]) + kids = get(children, i, Int[]) + if randomized && rng !== nothing + Random.shuffle!(rng, kids) + else + sort!(kids) + end + for j in kids + dfs(j) + end + end + for r in roots + dfs(r) + end + return build_eval_order_from_discrete_order(model, order_z) +end + +""" + build_hmt_bfs_order(model) + +Construct a discrete-first BFS order for HMTs by level, discovered from the +graph. Returns a full evaluation order after lifting. +""" +function build_hmt_bfs_order(model) + gd = model.graph_evaluation_data + nodes = gd.sorted_nodes + n = length(nodes) + pos = Dict(nodes[i] => i for i in 1:n) + z_index = Dict{typeof(nodes[1]),Int}() + for (j, vn) in enumerate(nodes) + s = string(vn) + if startswith(s, "z[") + i = try parse(Int, s[3:end-1]) catch; -1 end + if i > 0 + z_index[vn] = i + end + end + end + children = Dict{Int,Vector{Int}}() + has_parent = Dict{Int,Bool}() + for (vn, i) in z_index + kids = Int[] + for w in MetaGraphsNext.outneighbor_labels(model.g, vn) + if haskey(z_index, w) + j = z_index[w] + push!(kids, j) + has_parent[j] = true + end + end + children[i] = kids + has_parent[i] = get(has_parent, i, false) + end + roots = sort([i for (i, hp) in has_parent if !hp]) + # BFS + order_z = Int[] + z_vn_by_i = Dict{Int,typeof(nodes[1])}() + for (vn, i) in z_index + z_vn_by_i[i] = vn + end + queue = copy(roots) + while !isempty(queue) + i = popfirst!(queue) + push!(order_z, pos[z_vn_by_i[i]]) + kids = sort(get(children, i, Int[])) + append!(queue, kids) + end + return build_eval_order_from_discrete_order(model, order_z) +end + +# ========================== +# Graph Heuristic Orderings +# ========================== + +""" + build_primal_graph(model) + +Build the primal graph (moralized graph) for discrete finite unobserved variables. 
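+Two discrete variables become neighbors whenever they appear together in some
+node's stochastic-parent set, i.e. whenever they share a stochastic child
+(moralization).
+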
+Returns: +- adj: adjacency list mapping label index -> Set of neighbor label indices +- discrete_vars: sorted list of discrete finite unobserved variable label indices +""" +function build_primal_graph(model) + gd = model.graph_evaluation_data + n = length(gd.sorted_nodes) + st_parents = JModel._get_stochastic_parents_indices(model) + + # Identify discrete finite unobserved variables + discrete_vars = Int[] + for i in 1:n + if gd.is_discrete_finite_vals[i] && !gd.is_observed_vals[i] && gd.is_stochastic_vals[i] + push!(discrete_vars, i) + end + end + sort!(discrete_vars) + + # Build adjacency: connect variables that appear together in any stochastic parent set + adj = Dict{Int,Set{Int}}(i => Set{Int}() for i in discrete_vars) + + for j in 1:n + if gd.is_stochastic_vals[j] + # Get discrete finite unobserved parents of this node + disc_parents = [p for p in st_parents[j] if p in discrete_vars] + # Moralize: connect all pairs of parents + for i in 1:length(disc_parents) + for k in (i+1):length(disc_parents) + p1, p2 = disc_parents[i], disc_parents[k] + push!(adj[p1], p2) + push!(adj[p2], p1) + end + end + end + end + + return adj, discrete_vars +end + +""" + count_fill_edges(adj, var, eliminated) + +Count the number of fill edges that would be created by eliminating `var`. +Fill edges connect neighbors of `var` that are not already connected. +""" +function count_fill_edges(adj::Dict{Int,Set{Int}}, var::Int, eliminated::Set{Int}) + neighbors = setdiff(adj[var], eliminated) + fill_count = 0 + neighbors_vec = collect(neighbors) + for i in 1:length(neighbors_vec) + for j in (i+1):length(neighbors_vec) + n1, n2 = neighbors_vec[i], neighbors_vec[j] + if n2 ∉ adj[n1] + fill_count += 1 + end + end + end + return fill_count +end + +""" + eliminate_variable!(adj, var, eliminated) + +Eliminate `var` from the graph by connecting all its remaining neighbors (creating fill edges). +Updates `adj` in place and adds `var` to `eliminated`. +""" +function eliminate_variable!(adj::Dict{Int,Set{Int}}, var::Int, eliminated::Set{Int}) + neighbors = setdiff(adj[var], eliminated) + neighbors_vec = collect(neighbors) + # Add fill edges: connect all pairs of neighbors + for i in 1:length(neighbors_vec) + for j in (i+1):length(neighbors_vec) + n1, n2 = neighbors_vec[i], neighbors_vec[j] + push!(adj[n1], n2) + push!(adj[n2], n1) + end + end + push!(eliminated, var) + return nothing +end + +""" + build_min_fill_order(model; rng=nothing, num_restarts=3) + +Construct a discrete elimination order using the min-fill heuristic with randomized tie-breaking. +Greedily eliminates the variable that creates the fewest fill edges. +When ties occur, breaks them randomly and runs multiple restarts, returning the best order. 
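+Each restart's elimination order is lifted to a full evaluation order and scored
+with `frontier_cost_proxy`; the restart with the lowest log proxy cost is returned.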
+ +# Arguments +- `model`: The BUGSModel +- `rng`: Random number generator for tie-breaking (default: MersenneTwister(42)) +- `num_restarts`: Number of randomized restarts when ties occur (default: 3) +""" +function build_min_fill_order(model; rng=nothing, num_restarts=3) + if rng === nothing + rng = Random.MersenneTwister(42) + end + + best_order = Int[] + best_cost = Inf + + for restart in 1:num_restarts + adj, discrete_vars = build_primal_graph(model) + eliminated = Set{Int}() + order = Int[] + + while length(order) < length(discrete_vars) + remaining = setdiff(discrete_vars, eliminated) + + # Find variables with minimum fill + min_fill = typemax(Int) + candidates = Int[] + + for var in remaining + fill = count_fill_edges(adj, var, eliminated) + if fill < min_fill + min_fill = fill + candidates = [var] + elseif fill == min_fill + push!(candidates, var) + end + end + + # Break ties randomly + chosen = if length(candidates) == 1 + candidates[1] + else + candidates[rand(rng, 1:length(candidates))] + end + + push!(order, chosen) + eliminate_variable!(adj, chosen, eliminated) + end + + # Evaluate cost using frontier proxy + test_model = make_model_with_order(model, build_eval_order_from_discrete_order(model, order)) + _, _, _, log_cost = frontier_cost_proxy(test_model; K_hint=2) + + if log_cost < best_cost + best_cost = log_cost + best_order = order + end + end + + return build_eval_order_from_discrete_order(model, best_order) +end + +""" + count_degree(adj, var, eliminated) + +Count the degree of `var` (number of remaining neighbors). +""" +function count_degree(adj::Dict{Int,Set{Int}}, var::Int, eliminated::Set{Int}) + return length(setdiff(adj[var], eliminated)) +end + +""" + build_min_degree_order(model; rng=nothing, num_restarts=3) + +Construct a discrete elimination order using the min-degree heuristic with randomized tie-breaking. +Greedily eliminates the variable with the smallest degree (fewest remaining neighbors). +When ties occur, breaks them randomly and runs multiple restarts, returning the best order. 
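+As with `build_min_fill_order`, each restart's order is lifted and scored via
+`frontier_cost_proxy`, and the cheapest restart wins.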
+ +# Arguments +- `model`: The BUGSModel +- `rng`: Random number generator for tie-breaking (default: MersenneTwister(42)) +- `num_restarts`: Number of randomized restarts when ties occur (default: 3) +""" +function build_min_degree_order(model; rng=nothing, num_restarts=3) + if rng === nothing + rng = Random.MersenneTwister(42) + end + + best_order = Int[] + best_cost = Inf + + for restart in 1:num_restarts + adj, discrete_vars = build_primal_graph(model) + eliminated = Set{Int}() + order = Int[] + + while length(order) < length(discrete_vars) + remaining = setdiff(discrete_vars, eliminated) + + # Find variables with minimum degree + min_degree = typemax(Int) + candidates = Int[] + + for var in remaining + deg = count_degree(adj, var, eliminated) + if deg < min_degree + min_degree = deg + candidates = [var] + elseif deg == min_degree + push!(candidates, var) + end + end + + # Break ties randomly + chosen = if length(candidates) == 1 + candidates[1] + else + candidates[rand(rng, 1:length(candidates))] + end + + push!(order, chosen) + eliminate_variable!(adj, chosen, eliminated) + end + + # Evaluate cost using frontier proxy + test_model = make_model_with_order(model, build_eval_order_from_discrete_order(model, order)) + _, _, _, log_cost = frontier_cost_proxy(test_model; K_hint=2) + + if log_cost < best_cost + best_cost = log_cost + best_order = order + end + end + + return build_eval_order_from_discrete_order(model, best_order) +end diff --git a/JuliaBUGS/src/JuliaBUGS.jl b/JuliaBUGS/src/JuliaBUGS.jl index e340d8470..bc9f79022 100644 --- a/JuliaBUGS/src/JuliaBUGS.jl +++ b/JuliaBUGS/src/JuliaBUGS.jl @@ -42,7 +42,8 @@ using .Model: BUGSModel, evaluate_with_values!!, UseGraph, - UseGeneratedLogDensityFunction + UseGeneratedLogDensityFunction, + UseAutoMarginalization include("independent_mh.jl") include("gibbs.jl") diff --git a/JuliaBUGS/src/model/Model.jl b/JuliaBUGS/src/model/Model.jl index 37ca24aa8..c67d70507 100644 --- a/JuliaBUGS/src/model/Model.jl +++ b/JuliaBUGS/src/model/Model.jl @@ -9,6 +9,7 @@ using Graphs using LinearAlgebra using JuliaBUGS: JuliaBUGS, BUGSGraph using JuliaBUGS.BUGSPrimitives +using LogExpFunctions using MetaGraphsNext using Random @@ -21,5 +22,8 @@ include("logdensityproblems.jl") export parameters, variables, initialize!, getparams, settrans, set_evaluation_mode export regenerate_log_density_function, set_observed_values! export evaluate_with_rng!!, evaluate_with_env!!, evaluate_with_values!! +export evaluate_with_marginalization_rng!!, + evaluate_with_marginalization_env!!, evaluate_with_marginalization_values!! +export UseAutoMarginalization, enumerate_discrete_values end # Model diff --git a/JuliaBUGS/src/model/abstractppl.jl b/JuliaBUGS/src/model/abstractppl.jl index 5524ef222..d34f5ade4 100644 --- a/JuliaBUGS/src/model/abstractppl.jl +++ b/JuliaBUGS/src/model/abstractppl.jl @@ -577,7 +577,7 @@ function _create_modified_model( # Recompute mutable symbols for the new graph new_mutable_symbols = get_mutable_symbols(updated_graph_evaluation_data) - # Create the new model with all updated fields + # Create the new model with all updated fields (without auto-marg caches yet) kwargs = Dict{Symbol,Any}( :untransformed_param_length => new_untransformed_param_length, :transformed_param_length => new_transformed_param_length, @@ -599,7 +599,34 @@ function _create_modified_model( kwargs[:base_model] = base_model end - return BUGSModel(model; kwargs...) + new_model = BUGSModel(model; kwargs...) 
+ + # Compute and attach auto-marg caches once for the new graph + try + order = JuliaBUGS.Model._compute_marginalization_order(new_model) + keys = JuliaBUGS.Model._precompute_minimal_cache_keys(new_model, order) + + gd = new_model.graph_evaluation_data + gd_cached = GraphEvaluationData{ + typeof(gd.node_function_vals),typeof(gd.loop_vars_vals) + }( + gd.sorted_nodes, + gd.sorted_parameters, + gd.is_stochastic_vals, + gd.is_observed_vals, + gd.node_function_vals, + gd.loop_vars_vals, + gd.node_types, + gd.is_discrete_finite_vals, + keys, + order, + ) + return BUGSModel(new_model; graph_evaluation_data=gd_cached) + catch err + @warn "Failed to precompute auto-marginalization caches; falling back to graph evaluation" exception=(err, catch_backtrace()) + # Ensure the regenerated model does not stay in an inconsistent evaluation mode + return BangBang.setproperty!!(new_model, :evaluation_mode, UseGraph()) + end end # Common helper function to regenerate log density function @@ -768,8 +795,14 @@ function evaluate!!( temperature=1.0, transformed=model.transformed, ) - evaluation_env, log_densities = evaluate_with_values!!( - model, flattened_values; temperature=temperature, transformed=transformed - ) + if model.evaluation_mode isa UseAutoMarginalization + evaluation_env, log_densities = evaluate_with_marginalization_values!!( + model, flattened_values; temperature=temperature, transformed=transformed + ) + else + evaluation_env, log_densities = evaluate_with_values!!( + model, flattened_values; temperature=temperature, transformed=transformed + ) + end return evaluation_env, log_densities.tempered_logjoint end diff --git a/JuliaBUGS/src/model/bugsmodel.jl b/JuliaBUGS/src/model/bugsmodel.jl index 2b1d92c25..4654ed5d8 100644 --- a/JuliaBUGS/src/model/bugsmodel.jl +++ b/JuliaBUGS/src/model/bugsmodel.jl @@ -3,6 +3,54 @@ # instead of https://github.com/TuringLang/AbstractMCMC.jl/blob/d7c549fe41a80c1f164423c7ac458425535f624b/src/logdensityproblems.jl#L90 abstract type AbstractBUGSModel end +""" + is_discrete_finite_distribution(dist) + +Check if a distribution is discrete with finite support. +""" +function is_discrete_finite_distribution(dist) + # Check if it's a discrete distribution first + if !(dist isa Distributions.DiscreteUnivariateDistribution) + return false + end + + # Whitelist of known finite discrete distributions + return dist isa Union{ + Distributions.Bernoulli, + Distributions.Binomial, + Distributions.Categorical, + Distributions.DiscreteUniform, + Distributions.BetaBinomial, + Distributions.Hypergeometric, + } +end + +""" + enumerate_discrete_values(dist) + +Return the finite support for a discrete univariate distribution. +Relies on Distributions.support to provide an iterable, finite range. +""" +enumerate_discrete_values(dist::Distributions.DiscreteUnivariateDistribution) = Distributions.support( + dist +) + +""" + classify_node_type(dist) + +Classify a distribution into node types for marginalization. 
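+The classification drives evaluation under `UseAutoMarginalization`: nodes tagged
+`:discrete_finite` are enumerated and marginalized out, while continuous parameters
+are read from the flattened (transformed) parameter vector.
+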
+Returns one of: :deterministic, :discrete_finite, :discrete_infinite, :continuous +""" +function classify_node_type(dist) + if is_discrete_finite_distribution(dist) + return :discrete_finite + elseif dist isa Distributions.DiscreteUnivariateDistribution + return :discrete_infinite + else + return :continuous + end +end + """ GraphEvaluationData{TNF,TV} @@ -16,6 +64,8 @@ Stores pre-computed values to avoid repeated lookups from the MetaGraph during m - `is_observed_vals::Vector{Bool}`: Whether each node is observed (has data) - `node_function_vals::TNF`: Functions that define each node's computation - `loop_vars_vals::TV`: Loop variables associated with each node +- `node_types::Vector{Symbol}`: Node type classification (:deterministic, :discrete_finite, :discrete_infinite, :continuous) +- `is_discrete_finite_vals::Vector{Bool}`: Whether each node is a discrete variable with finite support """ struct GraphEvaluationData{TNF,TV} sorted_nodes::Vector{<:VarName} @@ -24,8 +74,18 @@ struct GraphEvaluationData{TNF,TV} is_observed_vals::Vector{Bool} node_function_vals::TNF loop_vars_vals::TV + node_types::Vector{Symbol} + is_discrete_finite_vals::Vector{Bool} + minimal_cache_keys::Dict{Int,Vector{Int}} + marginalization_order::Vector{Int} end +""" + GraphEvaluationData(compat constructor) + +Backward-compatible constructor that fills new caching fields with defaults +when older call sites provide only the first nine fields. +""" function GraphEvaluationData( g::BUGSGraph, sorted_nodes::Vector{<:VarName}=VarName[ @@ -37,6 +97,8 @@ function GraphEvaluationData( is_observed_vals = Array{Bool}(undef, length(sorted_nodes)) node_function_vals = Array{Any}(undef, length(sorted_nodes)) loop_vars_vals = Array{Any}(undef, length(sorted_nodes)) + node_types = Array{Symbol}(undef, length(sorted_nodes)) + is_discrete_finite_vals = Array{Bool}(undef, length(sorted_nodes)) sorted_parameters = VarName[] for (i, vn) in enumerate(sorted_nodes) @@ -46,6 +108,10 @@ function GraphEvaluationData( node_function_vals[i] = node_function loop_vars_vals[i] = loop_vars + # Default node types - will be updated during BUGSModel construction + node_types[i] = :continuous + is_discrete_finite_vals[i] = false + # If it's a stochastic variable and not observed, it's a parameter # If active_parameters is specified, only include those that are in the list if is_stochastic && !is_observed @@ -55,13 +121,17 @@ function GraphEvaluationData( end end - return GraphEvaluationData( + return GraphEvaluationData{typeof(node_function_vals),typeof(loop_vars_vals)}( sorted_nodes, sorted_parameters, is_stochastic_vals, is_observed_vals, map(identity, node_function_vals), map(identity, loop_vars_vals), + node_types, + is_discrete_finite_vals, + Dict{Int,Vector{Int}}(), + Int[], ) end @@ -69,6 +139,7 @@ abstract type EvaluationMode end struct UseGeneratedLogDensityFunction <: EvaluationMode end struct UseGraph <: EvaluationMode end +struct UseAutoMarginalization <: EvaluationMode end """ BUGSModel @@ -144,7 +215,8 @@ function BUGSModel( model_def::Expr=model.model_def, data=model.data, ) - return BUGSModel( + # Build an intermediate model + m = BUGSModel( model_def, data, g, @@ -160,6 +232,44 @@ function BUGSModel( mutable_symbols, base_model, ) + # Precompute minimal cache keys for current evaluation order if not present + gd = m.graph_evaluation_data + minimal_keys = if !isempty(gd.minimal_cache_keys) + gd.minimal_cache_keys + else + n = length(gd.sorted_nodes) + JuliaBUGS.Model._precompute_minimal_cache_keys(m, collect(1:n)) + end + # Attach 
minimal cache keys to GraphEvaluationData (order remains default) + gd2 = GraphEvaluationData{typeof(gd.node_function_vals),typeof(gd.loop_vars_vals)}( + gd.sorted_nodes, + gd.sorted_parameters, + gd.is_stochastic_vals, + gd.is_observed_vals, + gd.node_function_vals, + gd.loop_vars_vals, + gd.node_types, + gd.is_discrete_finite_vals, + minimal_keys, + gd.marginalization_order, + ) + # Return final model with cached minimal keys + return BUGSModel( + model_def, + data, + g, + evaluation_env, + transformed, + evaluation_mode, + untransformed_param_length, + transformed_param_length, + untransformed_var_lengths, + transformed_var_lengths, + gd2, + log_density_computation_function, + mutable_symbols, + base_model, + ) end function Base.show(io::IO, model::BUGSModel) @@ -219,6 +329,10 @@ function BUGSModel( untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(), Dict{VarName,Int}() + # Create mutable copies of node_types and is_discrete_finite_vals for updating + node_types = copy(graph_evaluation_data.node_types) + is_discrete_finite_vals = copy(graph_evaluation_data.is_discrete_finite_vals) + for (i, vn) in enumerate(graph_evaluation_data.sorted_nodes) is_stochastic = graph_evaluation_data.is_stochastic_vals[i] is_observed = graph_evaluation_data.is_observed_vals[i] @@ -226,31 +340,59 @@ function BUGSModel( loop_vars = graph_evaluation_data.loop_vars_vals[i] if !is_stochastic + # Deterministic node + node_types[i] = :deterministic + is_discrete_finite_vals[i] = false value = Base.invokelatest(node_function, evaluation_env, loop_vars) evaluation_env = BangBang.setindex!!(evaluation_env, value, vn) - elseif !is_observed + else + # Stochastic node - evaluate distribution and classify dist = Base.invokelatest(node_function, evaluation_env, loop_vars) - untransformed_var_lengths[vn] = length(dist) - # not all distributions are defined for `Bijectors.transformed` - transformed_var_lengths[vn] = if Bijectors.bijector(dist) == identity - untransformed_var_lengths[vn] - else - length(Bijectors.transformed(dist)) - end - untransformed_param_length += untransformed_var_lengths[vn] - transformed_param_length += transformed_var_lengths[vn] - - if haskey(initial_params, AbstractPPL.getsym(vn)) - initialization = AbstractPPL.get(initial_params, vn) - evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn) - else - init_value = rand(dist) - evaluation_env = BangBang.setindex!!(evaluation_env, init_value, vn) + # Classify the node type based on the distribution + node_types[i] = classify_node_type(dist) + is_discrete_finite_vals[i] = (node_types[i] == :discrete_finite) + + if !is_observed + # Unobserved stochastic node (parameter) + untransformed_var_lengths[vn] = length(dist) + # not all distributions are defined for `Bijectors.transformed` + transformed_var_lengths[vn] = if Bijectors.bijector(dist) == identity + untransformed_var_lengths[vn] + else + length(Bijectors.transformed(dist)) + end + untransformed_param_length += untransformed_var_lengths[vn] + transformed_param_length += transformed_var_lengths[vn] + + if haskey(initial_params, AbstractPPL.getsym(vn)) + initialization = AbstractPPL.get(initial_params, vn) + evaluation_env = BangBang.setindex!!(evaluation_env, initialization, vn) + else + init_value = rand(dist) + evaluation_env = BangBang.setindex!!(evaluation_env, init_value, vn) + end end end end + # Update graph_evaluation_data with the computed node types + graph_evaluation_data = GraphEvaluationData{ + typeof(graph_evaluation_data.node_function_vals), + 
typeof(graph_evaluation_data.loop_vars_vals), + }( + graph_evaluation_data.sorted_nodes, + graph_evaluation_data.sorted_parameters, + graph_evaluation_data.is_stochastic_vals, + graph_evaluation_data.is_observed_vals, + graph_evaluation_data.node_function_vals, + graph_evaluation_data.loop_vars_vals, + node_types, + is_discrete_finite_vals, + Dict{Int,Vector{Int}}(), + Int[], + ) + lowered_model_def, reconstructed_model_def = JuliaBUGS._generate_lowered_model_def( model_def, g, evaluation_env ) @@ -272,7 +414,48 @@ function BUGSModel( node in graph_evaluation_data.sorted_nodes end - graph_evaluation_data = GraphEvaluationData(g, sorted_nodes) + # Recreate GraphEvaluationData with the filtered sorted_nodes, but + # preserve previously computed node classifications. The earlier + # classification stored in `node_types` and `is_discrete_finite_vals` + # corresponds to `graph_evaluation_data.sorted_nodes` before filtering. + # A naive `GraphEvaluationData(g, sorted_nodes)` call would reset all + # node types to defaults, losing this information. + + # Build a mapping from VarName -> classification from the original order + old_nodes = graph_evaluation_data.sorted_nodes + type_map = Dict{VarName,Symbol}( + old_nodes[i] => node_types[i] for i in eachindex(old_nodes) + ) + disc_map = Dict{VarName,Bool}( + old_nodes[i] => is_discrete_finite_vals[i] for i in eachindex(old_nodes) + ) + + # Create a fresh GraphEvaluationData for the new order to reuse other fields + new_gd = GraphEvaluationData(g, sorted_nodes) + + # Remap classification arrays to the new order + new_node_types = Vector{Symbol}(undef, length(new_gd.sorted_nodes)) + new_is_discrete_finite_vals = Vector{Bool}(undef, length(new_gd.sorted_nodes)) + for (i, vn) in enumerate(new_gd.sorted_nodes) + new_node_types[i] = get(type_map, vn, :continuous) + new_is_discrete_finite_vals[i] = get(disc_map, vn, false) + end + + # Reconstruct GraphEvaluationData while preserving classification + graph_evaluation_data = GraphEvaluationData{ + typeof(new_gd.node_function_vals),typeof(new_gd.loop_vars_vals) + }( + new_gd.sorted_nodes, + new_gd.sorted_parameters, + new_gd.is_stochastic_vals, + new_gd.is_observed_vals, + new_gd.node_function_vals, + new_gd.loop_vars_vals, + new_node_types, + new_is_discrete_finite_vals, + Dict{Int,Vector{Int}}(), + Int[], + ) else log_density_computation_function = nothing end @@ -280,7 +463,8 @@ function BUGSModel( # Compute mutable symbols from graph evaluation data mutable_symbols = get_mutable_symbols(graph_evaluation_data) - return BUGSModel( + # Build initial model (without minimal cache keys precomputed) + model_without_min_keys = BUGSModel( model_def, data, g, @@ -296,6 +480,46 @@ function BUGSModel( mutable_symbols, nothing, ) + # Precompute marginalization order and minimal cache keys once + n = length(graph_evaluation_data.sorted_nodes) + sorted_indices = JuliaBUGS.Model._compute_marginalization_order(model_without_min_keys) + minimal_keys = JuliaBUGS.Model._precompute_minimal_cache_keys( + model_without_min_keys, sorted_indices + ) + # Attach cached order and keys to GraphEvaluationData + graph_evaluation_data_with_keys = GraphEvaluationData{ + typeof(graph_evaluation_data.node_function_vals), + typeof(graph_evaluation_data.loop_vars_vals), + }( + graph_evaluation_data.sorted_nodes, + graph_evaluation_data.sorted_parameters, + graph_evaluation_data.is_stochastic_vals, + graph_evaluation_data.is_observed_vals, + graph_evaluation_data.node_function_vals, + graph_evaluation_data.loop_vars_vals, + 
graph_evaluation_data.node_types, + graph_evaluation_data.is_discrete_finite_vals, + minimal_keys, + sorted_indices, + ) + + # Return final model with cached minimal keys + return BUGSModel( + model_def, + data, + g, + evaluation_env, + is_transformed, + UseGraph(), + untransformed_param_length, + transformed_param_length, + untransformed_var_lengths, + transformed_var_lengths, + graph_evaluation_data_with_keys, + log_density_computation_function, + mutable_symbols, + nothing, + ) end ## Model interface @@ -402,15 +626,33 @@ params_dict = getparams(Dict, model, custom_env) ``` """ function getparams(model::BUGSModel, evaluation_env=model.evaluation_env) - param_length = if model.transformed - model.transformed_param_length + # Determine which parameters to include based on evaluation mode + gd = model.graph_evaluation_data + param_vars = if model.evaluation_mode isa UseAutoMarginalization + # Only include continuous parameters when auto marginalizing + filter(gd.sorted_parameters) do vn + idx = findfirst(==(vn), gd.sorted_nodes) + idx !== nothing && gd.node_types[idx] == :continuous + end else - model.untransformed_param_length + gd.sorted_parameters + end + + # Compute total length for allocation + param_length = 0 + if model.transformed + for vn in param_vars + param_length += model.transformed_var_lengths[vn] + end + else + for vn in param_vars + param_length += model.untransformed_var_lengths[vn] + end end param_vals = Vector{Float64}(undef, param_length) pos = 1 - for v in model.graph_evaluation_data.sorted_parameters + for v in param_vars if !model.transformed val = AbstractPPL.get(evaluation_env, v) len = model.untransformed_var_lengths[v] @@ -441,7 +683,17 @@ function getparams( T::Type{<:AbstractDict}, model::BUGSModel, evaluation_env=model.evaluation_env ) d = T() - for v in model.graph_evaluation_data.sorted_parameters + gd = model.graph_evaluation_data + # Respect evaluation mode when selecting parameters + param_vars = if model.evaluation_mode isa UseAutoMarginalization + filter(gd.sorted_parameters) do vn + idx = findfirst(==(vn), gd.sorted_nodes) + idx !== nothing && gd.node_types[idx] == :continuous + end + else + gd.sorted_parameters + end + for v in param_vars value = AbstractPPL.get(evaluation_env, v) if !model.transformed d[v] = value @@ -499,17 +751,27 @@ model_with_generated_eval = set_evaluation_mode(model, UseGeneratedLogDensityFun ``` """ function set_evaluation_mode(model::BUGSModel, mode::EvaluationMode) - if isnothing(model.log_density_computation_function) - @warn( - "The model does not support generated log density function, the evaluation mode is set to `UseGraph`." - ) - mode = UseGraph() - elseif !model.transformed && mode isa UseGeneratedLogDensityFunction - error( - "Cannot use `UseGeneratedLogDensityFunction` with untransformed model. " * - "The generated log density function expects parameters in transformed (unconstrained) space. " * - "Please use `settrans(model, true)` before switching to generated log density mode.", - ) + if mode isa UseGeneratedLogDensityFunction + if isnothing(model.log_density_computation_function) + @warn( + "The model does not support generated log density function, the evaluation mode is set to `UseGraph`." + ) + mode = UseGraph() + elseif !model.transformed + error( + "Cannot use `UseGeneratedLogDensityFunction` with untransformed model. " * + "The generated log density function expects parameters in transformed (unconstrained) space. 
" * + "Please use `settrans(model, true)` before switching to generated log density mode.", + ) + end + elseif mode isa UseAutoMarginalization + if !model.transformed + error( + "Cannot use `UseAutoMarginalization` with untransformed model. " * + "Auto marginalization expects parameters in transformed (unconstrained) space. " * + "Please use `settrans(model, true)` before switching to auto marginalization mode.", + ) + end end return BangBang.setproperty!!(model, :evaluation_mode, mode) end diff --git a/JuliaBUGS/src/model/evaluation.jl b/JuliaBUGS/src/model/evaluation.jl index 3721ae28d..8505e0a8e 100644 --- a/JuliaBUGS/src/model/evaluation.jl +++ b/JuliaBUGS/src/model/evaluation.jl @@ -251,3 +251,571 @@ function evaluate_with_values!!( tempered_logjoint=logprior + temperature * loglikelihood, ) end + +# ====================== +# Marginalization Support +# ====================== + +""" + _get_stochastic_parents_indices(model::BUGSModel) + +Get the stochastic parents (through deterministic nodes) for each node in the model. +Returns a vector of index vectors aligned with sorted_nodes. +""" +function _get_stochastic_parents_indices(model::BUGSModel) + order = model.graph_evaluation_data.sorted_nodes + name_to_pos = Dict(order[i] => i for i in 1:length(order)) + is_stochastic = model.graph_evaluation_data.is_stochastic_vals + parents_idx = [Int[] for _ in 1:length(order)] + + for i in eachindex(order) + if is_stochastic[i] + # Use existing function to find stochastic parents through deterministic nodes + stochastic_parents, _ = JuliaBUGS.dfs_find_stochastic_boundary_and_deterministic_variables_en_route( + model.g, order[i], MetaGraphsNext.inneighbor_labels + ) + # Convert VarNames to indices + for parent_vn in stochastic_parents + if haskey(name_to_pos, parent_vn) + push!(parents_idx[i], name_to_pos[parent_vn]) + end + end + sort!(parents_idx[i]) # Keep sorted for stability + end + end + + return parents_idx +end + +""" + _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int}) + +Precompute minimal cache keys for memoization during marginalization. +The frontier at each position should include all discrete finite variables that: +1. Have been processed (appear earlier in the evaluation order) +2. May affect the current computation +""" +function _precompute_minimal_cache_keys(model::BUGSModel, order::Vector{Int}) + gd = model.graph_evaluation_data + n = length(order) + is_stochastic = gd.is_stochastic_vals + is_observed = gd.is_observed_vals + is_discrete_finite = gd.is_discrete_finite_vals + node_types = gd.node_types + + # Get stochastic parents (stochastic boundary) for each node + parents_idx = _get_stochastic_parents_indices(model) + + # Build mapping from node index (in gd.sorted_nodes) -> position in the provided order. + # This lets us reason about liveness w.r.t. the chosen evaluation order. + order_pos = Vector{Int}(undef, length(gd.sorted_nodes)) + @inbounds for k in 1:n + order_pos[order[k]] = k + end + + # Compute last-use POSITIONS (w.r.t. 'order') for each unobserved finite-discrete variable. + # A variable stays in the frontier until we pass the last stochastic node + # (observed or unobserved) whose distribution depends on it. 
+ last_use_pos = Dict{Int,Int}() # map from variable index -> last position in 'order' + for j_label in 1:length(gd.sorted_nodes) + if gd.is_stochastic_vals[j_label] + j_pos = order_pos[j_label] + for p_label in parents_idx[j_label] + if is_discrete_finite[p_label] && !is_observed[p_label] + # Default to the position of the variable itself if unseen + default_pos = order_pos[p_label] + last_use_pos[p_label] = max( + get(last_use_pos, p_label, default_pos), j_pos + ) + end + end + end + end + + # Initialize frontier keys for each position based on liveness + # Optimized incremental construction to avoid O(n^2) in common patterns + minimal_keys = Dict{Int,Vector{Int}}() + + # Precompute starts and ends in order positions + starts_at = Dict{Int,Vector{Int}}() + for lbl in 1:length(gd.sorted_nodes) + pos = order_pos[lbl] + if is_discrete_finite[lbl] && !is_observed[lbl] + push!(get!(starts_at, pos, Int[]), lbl) + end + end + + # Active set of earlier discrete finite variables (by label index) + active = Int[] + # Track end positions for active labels + function purge_expired!(active_vec::Vector{Int}, k_pos::Int) + # Remove any with last_use_pos < k_pos + i = 1 + while i <= length(active_vec) + lbl = active_vec[i] + if get(last_use_pos, lbl, 0) < k_pos + deleteat!(active_vec, i) + else + i += 1 + end + end + return active_vec + end + + for k in 1:n + # Add labels that start at previous position so they count as "earlier" + if haskey(starts_at, k - 1) + append!(active, starts_at[k - 1]) + end + # Drop any labels that have expired before current position + purge_expired!(active, k) + # Sort for stable key representation + sort!(active) + minimal_keys[order[k]] = copy(active) + end + + return minimal_keys +end + +""" + _compute_marginalization_order(model::BUGSModel) -> Vector{Int} + +Compute a topologically-valid evaluation order that reduces the frontier size +by placing discrete finite variables immediately before their observed dependents +whenever possible. This greatly reduces branching in the recursive enumerator. +""" +function _compute_marginalization_order(model::BUGSModel) + gd = model.graph_evaluation_data + n = length(gd.sorted_nodes) + + # Mapping VarName <-> index in sorted_nodes + order = gd.sorted_nodes + pos = Dict(order[i] => i for i in 1:n) + + # Direct parents via graph (for topo validity) + function parents(vn) + return collect(MetaGraphsNext.inneighbor_labels(model.g, vn)) + end + + # Keep track of which nodes are placed + placed = fill(false, n) + out = Int[] + + # Recursive placer that ensures all parents are placed first + function place_with_dependencies(vn::VarName) + i = pos[vn] + if placed[i] + return nothing + end + # Place all direct parents first + for p in parents(vn) + place_with_dependencies(p) + end + push!(out, i) + placed[i] = true + end + + # Identify observed stochastic nodes and their discrete-finite parents (via stochastic boundary) + # We use the existing helper to traverse through deterministic nodes + stoch_parents = _get_stochastic_parents_indices(model) + + # First, for each observed stochastic node, place its discrete-finite parents + # (and dependencies) immediately before placing the node itself. 
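+    # For an HMM this produces …, z[1], y[1], z[2], y[2], …: each z[t] is pulled in
+    # just before its observed child y[t], so only one discrete state is live at a time.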
+    for (i, vn) in enumerate(order)
+        if gd.is_stochastic_vals[i] && gd.is_observed_vals[i]
+            # Place discrete-finite unobserved parents (by label index -> VarName)
+            for pidx in stoch_parents[i]
+                if gd.is_discrete_finite_vals[pidx] && !gd.is_observed_vals[pidx]
+                    place_with_dependencies(order[pidx])
+                end
+            end
+            # Then place the observed node itself (ensures mu/sigma/etc. also placed)
+            place_with_dependencies(vn)
+        end
+    end
+
+    # Finally, place any remaining nodes in topological order
+    for vn in order
+        if !placed[pos[vn]]
+            place_with_dependencies(vn)
+        end
+    end
+
+    return out
+end
+
+"""
+    _marginalize_recursive(model, env, remaining_indices, parameter_values, param_offsets,
+                           var_lengths, memo, minimal_keys)
+
+Recursively compute log probability by marginalizing over discrete finite variables.
+"""
+function _marginalize_recursive(
+    model::BUGSModel,
+    env::NamedTuple,
+    remaining_indices::AbstractVector{Int},
+    parameter_values::AbstractVector,
+    param_offsets::Dict{VarName,Int},
+    var_lengths::Dict{VarName,Int},
+    memo::Dict{Tuple{Int,Tuple,Tuple},Any},
+    minimal_keys,
+)
+    # Base case: no more nodes to process
+    if isempty(remaining_indices)
+        return 0.0, 0.0
+    end
+
+    current_idx = remaining_indices[1]
+    current_vn = model.graph_evaluation_data.sorted_nodes[current_idx]
+
+    # Create memo key using minimal frontier
+    # Get the discrete finite frontier indices for this position (already sorted)
+    discrete_frontier_indices = get(minimal_keys, current_idx, Int[])
+
+    # Extract values only for discrete finite frontier variables
+    if !isempty(discrete_frontier_indices)
+        # These are discrete values set by enumeration, no AD wrapping
+        frontier_values = [
+            AbstractPPL.get(env, model.graph_evaluation_data.sorted_nodes[idx]) for
+            idx in discrete_frontier_indices
+        ]
+        frontier_indices_tuple = Tuple(discrete_frontier_indices)
+        frontier_values_tuple = Tuple(frontier_values)
+    else
+        frontier_indices_tuple = ()
+        frontier_values_tuple = ()
+    end
+    # With parameter access keyed by variable name, results depend only on the
+    # current node and the discrete frontier state. Continuous parameters are
+    # global and constant for a given input vector. 
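+    # e.g. a key (idx_of_y3, (idx_of_z3,), (2,)) reads: "node y[3], with frontier
+    # variable z[3] currently enumerated to the value 2".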
+ memo_key = (current_idx, frontier_indices_tuple, frontier_values_tuple) + + if haskey(memo, memo_key) + return memo[memo_key] + end + + is_stochastic = model.graph_evaluation_data.is_stochastic_vals[current_idx] + is_observed = model.graph_evaluation_data.is_observed_vals[current_idx] + is_discrete_finite = model.graph_evaluation_data.is_discrete_finite_vals[current_idx] + node_function = model.graph_evaluation_data.node_function_vals[current_idx] + loop_vars = model.graph_evaluation_data.loop_vars_vals[current_idx] + + result_prior = 0.0 + result_lik = 0.0 + + if !is_stochastic + # Deterministic node + value = node_function(env, loop_vars) + new_env = BangBang.setindex!!(env, value, current_vn) + result_prior, result_lik = _marginalize_recursive( + model, + new_env, + @view(remaining_indices[2:end]), + parameter_values, + param_offsets, + var_lengths, + memo, + minimal_keys, + ) + + elseif is_observed + # Observed stochastic node + dist = node_function(env, loop_vars) + obs_value = AbstractPPL.get(env, current_vn) + obs_logp = logpdf(dist, obs_value) + + # Handle NaN values + if isnan(obs_logp) + obs_logp = -Inf + end + + rest_prior, rest_lik = _marginalize_recursive( + model, + env, + @view(remaining_indices[2:end]), + parameter_values, + param_offsets, + var_lengths, + memo, + minimal_keys, + ) + result_prior = rest_prior + result_lik = obs_logp + rest_lik + + elseif is_discrete_finite + # Discrete finite unobserved node - marginalize out + dist = node_function(env, loop_vars) + possible_values = enumerate_discrete_values(dist) + + total_logpriors = nothing + branch_loglikelihoods = nothing + + for (i, value) in enumerate(possible_values) + branch_env = BangBang.setindex!!(env, value, current_vn) + + value_logp = logpdf(dist, value) + if isnan(value_logp) + value_logp = -Inf + end + + branch_prior, branch_lik = _marginalize_recursive( + model, + branch_env, + @view(remaining_indices[2:end]), + parameter_values, + param_offsets, + var_lengths, + memo, + minimal_keys, + ) + + total_val = value_logp + branch_prior + lik_val = branch_lik + if total_logpriors === nothing + total_logpriors = Vector{typeof(total_val)}(undef, length(possible_values)) + branch_loglikelihoods = Vector{typeof(lik_val)}(undef, length(possible_values)) + end + total_logpriors[i] = total_val + branch_loglikelihoods[i] = lik_val + end + + @assert total_logpriors !== nothing && branch_loglikelihoods !== nothing + log_prior_total = LogExpFunctions.logsumexp(total_logpriors) + log_joint_total = LogExpFunctions.logsumexp(total_logpriors .+ branch_loglikelihoods) + if isfinite(log_prior_total) + result_prior = log_prior_total + result_lik = log_joint_total - log_prior_total + else + result_prior = log_prior_total + result_lik = log_joint_total + end + + else + # Continuous or discrete infinite unobserved node - use parameter values + dist = node_function(env, loop_vars) + b = Bijectors.bijector(dist) + + if !haskey(var_lengths, current_vn) + error( + "Missing transformed length for variable '$(current_vn)'. 
" * + "All variables should have their transformed lengths pre-computed.", + ) + end + + l = var_lengths[current_vn] + # Fetch the start position for this variable from the precomputed map + start_idx = get(param_offsets, current_vn, 0) + if start_idx == 0 + error("Missing parameter offset for variable '$(current_vn)'.") + end + if start_idx + l - 1 > length(parameter_values) + error( + "Parameter index out of bounds: needed $(start_idx + l - 1) elements, " * + "but parameter_values has only $(length(parameter_values)) elements.", + ) + end + + b_inv = Bijectors.inverse(b) + param_slice = view(parameter_values, start_idx:(start_idx + l - 1)) + + reconstructed_value = reconstruct(b_inv, dist, param_slice) + value, logjac = Bijectors.with_logabsdet_jacobian(b_inv, reconstructed_value) + + new_env = BangBang.setindex!!(env, value, current_vn) + + dist_logp = logpdf(dist, value) + if isnan(dist_logp) + dist_logp = -Inf + else + dist_logp += logjac + end + + rest_prior, rest_lik = _marginalize_recursive( + model, + new_env, + @view(remaining_indices[2:end]), + parameter_values, + param_offsets, + var_lengths, + memo, + minimal_keys, + ) + + result_prior = dist_logp + rest_prior + result_lik = rest_lik + end + + memo[memo_key] = (result_prior, result_lik) + return result_prior, result_lik +end + +""" + evaluate_with_marginalization_rng!!( + rng::Random.AbstractRNG, + model::BUGSModel; + temperature=1.0, + transformed=true + ) + +Evaluate model using marginalization for discrete finite variables and sampling for others. +""" +function evaluate_with_marginalization_rng!!( + rng::Random.AbstractRNG, model::BUGSModel; temperature=1.0, transformed=true +) + if !transformed + error( + "Auto marginalization only supports transformed (unconstrained) parameter space. " * + "Please use transformed=true.", + ) + end + + # For RNG-based evaluation, we don't marginalize - we sample discrete variables + # This is similar to evaluate_with_rng!! but could be extended for hybrid approaches + return evaluate_with_rng!!( + rng, model; sample_all=true, temperature=temperature, transformed=transformed + ) +end + +""" + evaluate_with_marginalization_env!!( + model::BUGSModel, + evaluation_env=smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols); + temperature=1.0, + transformed=true + ) + +Evaluate model using marginalization for discrete finite variables. +""" +function evaluate_with_marginalization_env!!( + model::BUGSModel, + evaluation_env=smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols); + temperature=1.0, + transformed=true, +) + if !transformed + error( + "Auto marginalization only supports transformed (unconstrained) parameter space. 
" * + "Please use transformed=true.", + ) + end + + # For environment-based evaluation without explicit parameter values, + # we need to extract ONLY continuous parameters for marginalization + gd = model.graph_evaluation_data + param_values = Float64[] + + for vn in gd.sorted_parameters + idx = findfirst(==(vn), gd.sorted_nodes) + if idx !== nothing && gd.node_types[idx] == :continuous + value = AbstractPPL.get(evaluation_env, vn) + if transformed + # Transform to unconstrained space + (; node_function, loop_vars) = model.g[vn] + dist = node_function(evaluation_env, loop_vars) + transformed_value = Bijectors.transform(Bijectors.bijector(dist), value) + if transformed_value isa AbstractArray + append!(param_values, vec(transformed_value)) + else + push!(param_values, transformed_value) + end + else + if value isa AbstractArray + append!(param_values, vec(value)) + else + push!(param_values, value) + end + end + end + end + + return evaluate_with_marginalization_values!!( + model, param_values; temperature=temperature, transformed=transformed + ) +end + +""" + evaluate_with_marginalization_values!!( + model::BUGSModel, + flattened_values::AbstractVector; + temperature=1.0, + transformed=true + ) + +Evaluate model with marginalization over discrete finite variables. +""" +function evaluate_with_marginalization_values!!( + model::BUGSModel, flattened_values::AbstractVector; temperature=1.0, transformed=true +) + if !transformed + error( + "Auto marginalization only supports transformed (unconstrained) parameter space. " * + "Please use transformed=true.", + ) + end + + # Use cached marginalization order and minimal frontier keys when available + gd = model.graph_evaluation_data + n = length(gd.sorted_nodes) + # Strictly require caches to be present for performance + if isempty(gd.marginalization_order) || isempty(gd.minimal_cache_keys) + error( + "Auto marginalization cache missing. 
This model was not prepared for UseAutoMarginalization.", + ) + end + sorted_indices = gd.marginalization_order + minimal_keys = gd.minimal_cache_keys + + # Initialize memoization cache + # Size hint: at most 2^|discrete_finite| * |nodes| entries + n_discrete_finite = sum(model.graph_evaluation_data.is_discrete_finite_vals) + expected_entries = if n_discrete_finite > 20 + 1_000_000 # Cap at 1M for large problems + else + min((1 << n_discrete_finite) * n, 1_000_000) + end + memo = Dict{Tuple{Int,Tuple,Tuple},Any}() + sizehint!(memo, expected_entries) + + # Start recursive evaluation + evaluation_env = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols) + + # For marginalization, only continuous parameters need var_lengths + # Discrete finite variables are marginalized over, not sampled + var_lengths = Dict{VarName,Int}() + continuous_param_order = VarName[] + for vn in gd.sorted_parameters + idx = findfirst(==(vn), gd.sorted_nodes) + if idx !== nothing && gd.node_types[idx] == :continuous + push!(continuous_param_order, vn) + var_lengths[vn] = model.transformed_var_lengths[vn] + end + end + + # Build mapping from variable -> start index in flattened_values + param_offsets = Dict{VarName,Int}() + start = 1 + for vn in continuous_param_order + param_offsets[vn] = start + start += var_lengths[vn] + end + + log_prior, log_likelihood = _marginalize_recursive( + model, + evaluation_env, + sorted_indices, + flattened_values, + param_offsets, + var_lengths, + memo, + minimal_keys, + ) + + # For consistency with other evaluate functions, we return the environment + # and split the log probability (though marginalization combines them) + return evaluation_env, + ( + logprior=log_prior, + loglikelihood=log_likelihood, + tempered_logjoint=log_prior + temperature * log_likelihood, + ) +end diff --git a/JuliaBUGS/src/model/logdensityproblems.jl b/JuliaBUGS/src/model/logdensityproblems.jl index 07d82b018..f8738a8ef 100644 --- a/JuliaBUGS/src/model/logdensityproblems.jl +++ b/JuliaBUGS/src/model/logdensityproblems.jl @@ -9,15 +9,45 @@ function _eval_logdensity(model, ::UseGraph, x) return logp end +function _eval_logdensity(model, ::UseAutoMarginalization, x) + _, log_densities = evaluate_with_marginalization_values!!(model, x; transformed=true) + return log_densities.tempered_logjoint +end + function LogDensityProblems.logdensity(model::BUGSModel, x::AbstractArray) return _eval_logdensity(model, model.evaluation_mode, x) end function LogDensityProblems.dimension(model::BUGSModel) - return if model.transformed - model.transformed_param_length + # For auto marginalization, only count continuous parameters + if model.evaluation_mode isa UseAutoMarginalization + continuous_param_length = 0 + for (i, vn) in enumerate(model.graph_evaluation_data.sorted_parameters) + idx = findfirst(==(vn), model.graph_evaluation_data.sorted_nodes) + if idx !== nothing + node_type = model.graph_evaluation_data.node_types[idx] + # Only include continuous variables (exclude all discrete) + if node_type == :continuous + if model.transformed + continuous_param_length += model.transformed_var_lengths[vn] + else + continuous_param_length += model.untransformed_var_lengths[vn] + end + elseif node_type == :discrete_infinite + error( + "Model contains discrete infinite variable $(vn) which cannot be marginalized. 
" * + "Use UseGraph evaluation mode instead.", + ) + end + end + end + return continuous_param_length else - model.untransformed_param_length + return if model.transformed + model.transformed_param_length + else + model.untransformed_param_length + end end end diff --git a/JuliaBUGS/test/model/auto_marginalization.jl b/JuliaBUGS/test/model/auto_marginalization.jl new file mode 100644 index 000000000..11b9420d9 --- /dev/null +++ b/JuliaBUGS/test/model/auto_marginalization.jl @@ -0,0 +1,829 @@ +# Tests for auto-marginalization of discrete finite variables +# This file is included from runtests.jl which provides all necessary imports + +using JuliaBUGS: @bugs, compile, settrans, initialize!, getparams +using JuliaBUGS.Model: + set_evaluation_mode, + UseAutoMarginalization, + UseGraph, + evaluate_with_marginalization_values!! + +@testset "Auto-Marginalization" begin + println("[AutoMargTest] Starting Auto-Marginalization test suite..."); + flush(stdout) + # HMM helper function for ground truth using forward algorithm + function forward_algorithm_hmm(y, mu1, mu2, sigma, pi, trans) + T = length(y) + n_states = 2 + alpha = zeros(n_states, T) + + for s in 1:n_states + mu_s = s == 1 ? mu1 : mu2 + alpha[s, 1] = log(pi[s]) + logpdf(Normal(mu_s, sigma), y[1]) + end + + for t in 2:T + for s in 1:n_states + mu_s = s == 1 ? mu1 : mu2 + log_trans_probs = [ + alpha[s_prev, t - 1] + log(trans[s_prev, s]) for s_prev in 1:n_states + ] + alpha[s, t] = + LogExpFunctions.logsumexp(log_trans_probs) + + logpdf(Normal(mu_s, sigma), y[t]) + end + end + + return LogExpFunctions.logsumexp(alpha[:, T]) + end + + @testset "Simple HMM with fixed parameters" begin + println("[AutoMargTest] HMM (fixed params): compiling..."); + flush(stdout) + # HMM with fixed emission parameters (no continuous parameters to estimate) + hmm_fixed_def = @bugs begin + mu[1] = 0.0 + mu[2] = 5.0 + sigma = 1.0 + + trans[1, 1] = 0.7 + trans[1, 2] = 0.3 + trans[2, 1] = 0.4 + trans[2, 2] = 0.6 + + pi[1] = 0.5 + pi[2] = 0.5 + + z[1] ~ Categorical(pi[1:2]) + for t in 2:T + p[t, 1] = trans[z[t - 1], 1] + p[t, 2] = trans[z[t - 1], 2] + z[t] ~ Categorical(p[t, :]) + end + + for t in 1:T + y[t] ~ Normal(mu[z[t]], sigma) + end + end + + T = 2 + y_obs = [0.1, 4.9] + data = (T=T, y=y_obs) + + model = compile(hmm_fixed_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + println("[AutoMargTest] HMM (fixed params): evaluating logdensity..."); + flush(stdout) + + # No continuous parameters, so empty array + x_empty = Float64[] + logp_marginalized = LogDensityProblems.logdensity(model, x_empty) + + # Expected value (manual calculation) + expected = -3.744970426679133 + + @test isapprox(logp_marginalized, expected; atol=1e-6) + end + + @testset "HMM with continuous parameters" begin + # HMM where emission means and variance are parameters to be estimated + hmm_param_def = @bugs begin + # Priors for emission parameters + mu[1] ~ Normal(0, 10) + mu[2] ~ Normal(5, 10) + sigma ~ Exponential(1) + + # Fixed transition matrix + trans[1, 1] = 0.7 + trans[1, 2] = 0.3 + trans[2, 1] = 0.4 + trans[2, 2] = 0.6 + + # Initial state probabilities + pi[1] = 0.5 + pi[2] = 0.5 + + # Hidden states (discrete, to be marginalized) + z[1] ~ Categorical(pi[1:2]) + for t in 2:T + p[t, 1] = trans[z[t - 1], 1] + p[t, 2] = trans[z[t - 1], 2] + z[t] ~ Categorical(p[t, :]) + end + + # Observations + for t in 1:T + y[t] ~ Normal(mu[z[t]], sigma) + end + end + + @testset "T=$T" for T in [2, 3, 4, 5] + println("[AutoMargTest] HMM (params): 
T=$(T) compile+eval..."); + flush(stdout) + y_obs = if T == 2 + [0.1, 4.9] + elseif T == 3 + [0.1, 4.9, 5.1] + elseif T == 4 + [0.1, 4.9, 5.1, -0.2] + else # T == 5 + [0.1, 4.9, 5.1, -0.2, 5.0] + end + + data = (T=T, y=y_obs) + + model = compile(hmm_param_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Check dimension - should be 3 (sigma, mu[2], mu[1]) + @test LogDensityProblems.dimension(model) == 3 + + # Test with specific continuous parameters + # Order: sigma, mu[2], mu[1] (from sorted_parameters) + test_params = [0.0, 5.0, 0.0] # log(sigma)=0 -> sigma=1, mu[2]=5, mu[1]=0 + + logp_marginalized = LogDensityProblems.logdensity(model, test_params) + println("[AutoMargTest] HMM (params): T=$(T) logdensity done"); + flush(stdout) + + # Compute expected value using forward algorithm + pi_vals = [0.5, 0.5] + trans_mat = [0.7 0.3; 0.4 0.6] + logp_forward = forward_algorithm_hmm(y_obs, 0.0, 5.0, 1.0, pi_vals, trans_mat) + + # Add prior log probabilities + prior_logp = + logpdf(Normal(0, 10), 0.0) + + logpdf(Normal(5, 10), 5.0) + + logpdf(Exponential(1), 1.0) + expected = logp_forward + prior_logp + + @test isapprox(logp_marginalized, expected; atol=1e-10) + end + end + + @testset "Marginalization mode consistency" begin + # Test that UseAutoMarginalization correctly filters parameters + hmm_def = @bugs begin + mu[1] ~ Normal(0, 10) + mu[2] ~ Normal(5, 10) + sigma ~ Exponential(1) + + trans[1, 1] = 0.7 + trans[1, 2] = 0.3 + trans[2, 1] = 0.4 + trans[2, 2] = 0.6 + + pi[1] = 0.5 + pi[2] = 0.5 + + z[1] ~ Categorical(pi[1:2]) + for t in 2:T + p[t, 1] = trans[z[t - 1], 1] + p[t, 2] = trans[z[t - 1], 2] + z[t] ~ Categorical(p[t, :]) + end + + for t in 1:T + y[t] ~ Normal(mu[z[t]], sigma) + end + end + + T = 3 + data = (T=T, y=[0.1, 4.9, 5.1]) + + # Create model in graph mode + model_graph = compile(hmm_def, data) + model_graph = settrans(model_graph, true) + model_graph = set_evaluation_mode(model_graph, UseGraph()) + + # Create model in marginalization mode + model_marg = compile(hmm_def, data) + model_marg = settrans(model_marg, true) + model_marg = set_evaluation_mode(model_marg, UseAutoMarginalization()) + + # Graph mode should include discrete parameters + @test LogDensityProblems.dimension(model_graph) == 6 # z[1:3] + sigma + mu[2] + mu[1] + + # Marginalization mode should only include continuous parameters + @test LogDensityProblems.dimension(model_marg) == 3 # sigma + mu[2] + mu[1] + + # Check that discrete finite variables are correctly identified + gd = model_marg.graph_evaluation_data + discrete_count = sum(gd.is_discrete_finite_vals) + @test discrete_count == 3 # z[1], z[2], z[3] + end + + @testset "Gaussian Mixture Models" begin + println("[AutoMargTest] GMM tests: start..."); + flush(stdout) + # Helper function for ground truth mixture likelihood + function mixture_loglikelihood(y, weights, mus, sigmas) + n = length(y) + k = length(weights) + logp_total = 0.0 + + for i in 1:n + # Log-sum-exp over components for each observation + log_probs = zeros(k) + for j in 1:k + log_probs[j] = log(weights[j]) + logpdf(Normal(mus[j], sigmas[j]), y[i]) + end + logp_total += LogExpFunctions.logsumexp(log_probs) + end + + return logp_total + end + + @testset "Two-component mixture with fixed weights" begin + println("[AutoMargTest] GMM K=2 correctness..."); + flush(stdout) + # Simple mixture with fixed mixture weights + mixture_fixed_def = @bugs begin + # Fixed mixture weights + w[1] = 0.3 + w[2] = 0.7 + + # Component parameters + mu[1] ~ 
Normal(-2, 5) + mu[2] ~ Normal(2, 5) + sigma[1] ~ Exponential(1) + sigma[2] ~ Exponential(1) + + # Component assignments (discrete, to be marginalized) + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 4 + y_obs = [-1.5, 2.3, -2.1, 1.8] + data = (N=N, y=y_obs) + + model = compile(mixture_fixed_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Should have 4 continuous parameters: sigma[1], sigma[2], mu[2], mu[1] + @test LogDensityProblems.dimension(model) == 4 + + # Test with specific parameters + # Order: log(sigma[1]), log(sigma[2]), mu[2], mu[1] + test_params = [0.0, 0.0, 2.0, -2.0] # sigmas=1, mu[2]=2, mu[1]=-2 + + logp_marginalized = LogDensityProblems.logdensity(model, test_params) + + # Compute expected value + weights = [0.3, 0.7] + mus = [-2.0, 2.0] + sigmas = [1.0, 1.0] + + logp_likelihood = mixture_loglikelihood(y_obs, weights, mus, sigmas) + prior_logp = + logpdf(Normal(-2, 5), -2.0) + + logpdf(Normal(2, 5), 2.0) + + logpdf(Exponential(1), 1.0) + + logpdf(Exponential(1), 1.0) + expected = logp_likelihood + prior_logp + + @test isapprox(logp_marginalized, expected; atol=1e-10) + end + + @testset "Three-component mixture with fixed weights" begin + println("[AutoMargTest] GMM K=3 correctness..."); + flush(stdout) + # Extend to 3 components with exact verification + mixture_3comp_def = @bugs begin + # Fixed mixture weights + w[1] = 0.2 + w[2] = 0.5 + w[3] = 0.3 + + # Component parameters + mu[1] ~ Normal(-3, 5) + mu[2] ~ Normal(0, 5) + mu[3] ~ Normal(3, 5) + for k in 1:3 + sigma[k] ~ Exponential(1) + end + + # Component assignments + for i in 1:N + z[i] ~ Categorical(w[1:3]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 3 + y_obs = [-2.5, 0.5, 3.2] + data = (N=N, y=y_obs) + + model = compile(mixture_3comp_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Should have 6 continuous parameters: 3 sigmas + 3 mus + @test LogDensityProblems.dimension(model) == 6 + + # Test with specific parameters + test_params = [0.0, 0.0, 0.0, 3.0, 0.0, -3.0] + # log(sigmas)=0 -> all sigmas=1, mu[3]=3, mu[2]=0, mu[1]=-3 + + logp_marginalized = LogDensityProblems.logdensity(model, test_params) + + # Compute expected value + weights = [0.2, 0.5, 0.3] + mus = [-3.0, 0.0, 3.0] + sigmas = [1.0, 1.0, 1.0] + + logp_likelihood = mixture_loglikelihood(y_obs, weights, mus, sigmas) + prior_logp = sum([ + logpdf(Normal(-3, 5), -3.0), + logpdf(Normal(0, 5), 0.0), + logpdf(Normal(3, 5), 3.0), + logpdf(Exponential(1), 1.0), + logpdf(Exponential(1), 1.0), + logpdf(Exponential(1), 1.0), + ]) + expected = logp_likelihood + prior_logp + + @test isapprox(logp_marginalized, expected; atol=1e-10) + end + + @testset "Label invariance" begin + println("[AutoMargTest] GMM label invariance..."); + flush(stdout) + # Verify that permuting component labels doesn't change log-density + # when weights are equal + mixture_sym_def = @bugs begin + w[1] = 0.5 + w[2] = 0.5 + + mu[1] ~ Normal(0, 10) + mu[2] ~ Normal(0, 10) + sigma[1] ~ Exponential(1) + sigma[2] ~ Exponential(1) + + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 4 + y_obs = [1.0, 2.0, -1.0, 3.0] + data = (N=N, y=y_obs) + + model = compile(mixture_sym_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Test with original ordering + # Order: log(sigma[1]), log(sigma[2]), mu[2], 
mu[1] + params1 = [-0.5, 0.0, 3.0, 1.0] # sigma[1]=exp(-0.5), sigma[2]=1, mu[2]=3, mu[1]=1 + logp1 = LogDensityProblems.logdensity(model, params1) + + # Test with swapped components (swap mu and sigma values) + params2 = [0.0, -0.5, 1.0, 3.0] # sigma[1]=1, sigma[2]=exp(-0.5), mu[2]=1, mu[1]=3 + logp2 = LogDensityProblems.logdensity(model, params2) + + # The log probabilities should be equal due to symmetry + # (swapping components 1 and 2 completely with equal weights) + @test isapprox(logp1, logp2; atol=1e-10) + end + + @testset "Partial observation of z" begin + # Some z[i] are observed, others are marginalized + mixture_partial_def = @bugs begin + w[1] = 0.3 + w[2] = 0.7 + + mu[1] ~ Normal(-2, 5) + mu[2] ~ Normal(2, 5) + sigma ~ Exponential(1) # Shared sigma + + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma) + end + end + + N = 4 + # Observe z[1] and z[3], marginalize z[2] and z[4] + data = (N=N, y=[1.0, 2.0, -1.0, 3.0], z=[2, missing, 1, missing]) + + model = compile(mixture_partial_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Should have 3 continuous parameters: sigma, mu[2], mu[1] + # z[2] and z[4] are marginalized out + @test LogDensityProblems.dimension(model) == 3 + + # Test evaluation + test_params = [0.0, 2.0, -2.0] # log(sigma)=0->sigma=1, mu[2]=2, mu[1]=-2 + logp = LogDensityProblems.logdensity(model, test_params) + + # Verify it's finite and reasonable + @test isfinite(logp) + @test logp < 0 + + # Manually compute expected for observed components + # z[1]=2 -> y[1]=1.0 comes from mu[2]=2 + # z[3]=1 -> y[3]=-1.0 comes from mu[1]=-2 + # z[2] and z[4] are marginalized + sigma_val = 1.0 + mu_vals = [-2.0, 2.0] + weights = [0.3, 0.7] + + # Observed parts + logp_obs = ( + log(weights[2]) + + logpdf(Normal(mu_vals[2], sigma_val), 1.0) + # z[1]=2, y[1]=1.0 + log(weights[1]) + + logpdf(Normal(mu_vals[1], sigma_val), -1.0) # z[3]=1, y[3]=-1.0 + ) + + # Marginalized parts for y[2]=2.0 and y[4]=3.0 + logp_marg2 = LogExpFunctions.logsumexp([ + log(weights[1]) + logpdf(Normal(mu_vals[1], sigma_val), 2.0), + log(weights[2]) + logpdf(Normal(mu_vals[2], sigma_val), 2.0), + ]) + logp_marg4 = LogExpFunctions.logsumexp([ + log(weights[1]) + logpdf(Normal(mu_vals[1], sigma_val), 3.0), + log(weights[2]) + logpdf(Normal(mu_vals[2], sigma_val), 3.0), + ]) + + logp_likelihood = logp_obs + logp_marg2 + logp_marg4 + prior_logp = ( + logpdf(Normal(-2, 5), -2.0) + + logpdf(Normal(2, 5), 2.0) + + logpdf(Exponential(1), 1.0) + ) + expected = logp_likelihood + prior_logp + + @test isapprox(logp, expected; atol=1e-10) + end + + @testset "Mixture with Dirichlet prior on weights" begin + # More realistic mixture with learned weights + mixture_dirichlet_def = @bugs begin + # Mixture weights with Dirichlet prior + alpha[1] = 1.0 + alpha[2] = 1.0 + alpha[3] = 1.0 + w[1:3] ~ ddirich(alpha[1:3]) + + # Component parameters + for k in 1:3 + mu[k] ~ Normal(0, 10) + sigma[k] ~ Exponential(1) + end + + # Component assignments + for i in 1:N + z[i] ~ Categorical(w[1:3]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 5 + y_obs = [-3.0, 0.1, 2.9, -2.8, 3.1] + data = (N=N, y=y_obs) + + model = compile(mixture_dirichlet_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Should have 8 continuous parameters: + # 3 sigmas + 3 mus + 2 transformed weight components (3-1 due to simplex constraint) + @test LogDensityProblems.dimension(model) == 8 + + # Test with 
specific parameters + # Simplex transform for weights [0.2, 0.3, 0.5] + # Using stick-breaking: w1=0.2, w2=0.3, w3=0.5 + # This requires specific transformed values + w_target = [0.2, 0.3, 0.5] + # For Dirichlet, use log-ratio transform + log_ratios = [log(w_target[1] / w_target[3]), log(w_target[2] / w_target[3])] + + test_params = [ + 0.0, + 0.0, + 0.0, # log(sigmas) = 0 -> all sigmas = 1 + 3.0, + 0.0, + -3.0, # mu[3]=3, mu[2]=0, mu[1]=-3 + log_ratios[1], + log_ratios[2], # transformed weights + ] + + logp_marginalized = LogDensityProblems.logdensity(model, test_params) + + # Verify it's finite and reasonable + @test isfinite(logp_marginalized) + @test logp_marginalized < 0 # Should be negative for realistic parameters + end + + @testset "Hierarchical mixture model" begin + # Mixture with hierarchical structure on component means + hierarchical_mixture_def = @bugs begin + # Hyperpriors + mu_global ~ Normal(0, 10) + tau_global ~ Exponential(1) + + # Mixture weights + w[1] = 0.5 + w[2] = 0.5 + + # Component-specific parameters with hierarchical prior + for k in 1:2 + mu[k] ~ Normal(mu_global, tau_global) + sigma[k] ~ Exponential(1) + end + + # Data generation + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 6 + y_obs = [1.0, 1.2, 4.8, 5.1, 0.9, 5.0] + data = (N=N, y=y_obs) + + model = compile(hierarchical_mixture_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Should have 6 continuous parameters: + # mu_global, tau_global, 2 sigmas, 2 mus + @test LogDensityProblems.dimension(model) == 6 + + # Test evaluation with multiple parameter sets + # Test 1: Parameters that should give reasonable likelihood + test_params = [3.0, 0.0, 0.0, 0.0, 5.0, 1.0] + # mu_global=3, log(tau_global)=0->tau=1, log(sigmas)=0->sigmas=1, mu[2]=5, mu[1]=1 + + logp_marginalized = LogDensityProblems.logdensity(model, test_params) + + # Verify the result is finite and reasonable + @test isfinite(logp_marginalized) + @test logp_marginalized < 0 # Log probability should be negative + + # Test 2: Different parameters - should give different likelihood + test_params2 = [2.5, -0.5, -0.5, 0.2, 4.5, 0.5] + logp_marginalized2 = LogDensityProblems.logdensity(model, test_params2) + + @test isfinite(logp_marginalized2) + @test logp_marginalized2 != logp_marginalized # Different params should give different results + + # Test 3: Verify multiple evaluations are consistent + logp_repeat = LogDensityProblems.logdensity(model, test_params) + @test logp_repeat == logp_marginalized # Same params should give same result + end + end + + @testset "Edge cases" begin + @testset "Model with no discrete finite variables" begin + # Simple continuous model - marginalization should work but do nothing special + continuous_def = @bugs begin + mu ~ Normal(0, 10) + sigma ~ Exponential(1) + for i in 1:N + y[i] ~ Normal(mu, sigma) + end + end + + N = 3 + data = (N=N, y=[1.0, 2.0, 3.0]) + + model = compile(continuous_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + @test LogDensityProblems.dimension(model) == 2 # mu, sigma + + # Should work normally + test_params = [2.0, 0.0] # mu=2, log(sigma)=0 -> sigma=1 + logp = LogDensityProblems.logdensity(model, test_params) + @test isfinite(logp) + end + + @testset "Model with observed discrete variables" begin + # Discrete variables that are observed should not be marginalized + observed_discrete_def = @bugs begin + p ~ Beta(1, 1) + for 
i in 1:N + x[i] ~ Bernoulli(p) + end + end + + N = 5 + data = (N=N, x=[1, 0, 1, 1, 0]) + + model = compile(observed_discrete_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + @test LogDensityProblems.dimension(model) == 1 # Only p + + # Test evaluation + test_params = [0.0] # logit(p) = 0 -> p = 0.5 + logp = LogDensityProblems.logdensity(model, test_params) + + # Expected in transformed space: + # - Beta(1,1) prior at p=0.5: log(1) = 0 + # - Likelihood: 3 successes and 2 failures with p=0.5: 5*log(0.5) + # - Log Jacobian for logit transform at p=0.5: log(p*(1-p)) = log(0.25) + p_val = 0.5 + expected = + logpdf(Beta(1, 1), p_val) + + 3 * log(p_val) + + 2 * log(1 - p_val) + + log(p_val * (1 - p_val)) # Jacobian + @test isapprox(logp, expected; atol=1e-10) + end + end + + @testset "Gradient vs finite differences (GMM)" begin + println("[AutoMargTest] GMM gradients: compiling..."); + flush(stdout) + # Two-component mixture with fixed weights; params: mu[1:2], sigma[1:2] + mixture_def = @bugs begin + w[1] = 0.3 + w[2] = 0.7 + mu[1] ~ Normal(-2, 5) + mu[2] ~ Normal(2, 5) + sigma[1] ~ Exponential(1) + sigma[2] ~ Exponential(1) + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + N = 6 + y = [-1.8, -2.2, 1.9, 2.1, -1.5, 2.4] + data = (N=N, y=y) + model = compile(mixture_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + # Initialize model and extract parameter vector + initialize!(model, (; mu=[-2.0, 2.0], sigma=[1.1, 0.9])) + θ = getparams(model) + + # AD gradient via ForwardDiff + ad_model = ADgradient(AutoForwardDiff(), model) + println("[AutoMargTest] GMM gradients: AD gradient..."); + flush(stdout) + val_ad, grad_ad = LogDensityProblems.logdensity_and_gradient(ad_model, θ) + println("[AutoMargTest] GMM gradients: AD gradient done"); + flush(stdout) + + # Central finite differences + function f(θ) + LogDensityProblems.logdensity(model, θ) + end + ϵ = 1e-6 + grad_fd = similar(θ) + println("[AutoMargTest] GMM gradients: FD gradient..."); + flush(stdout) + for i in eachindex(θ) + e = zeros(length(θ)); + e[i] = 1.0 + fp = f(θ .+ ϵ .* e) + fm = f(θ .- ϵ .* e) + grad_fd[i] = (fp - fm) / (2ϵ) + println("[AutoMargTest] GMM gradients: FD step ", i, "/", length(θ)); + flush(stdout) + end + println("[AutoMargTest] GMM gradients: FD gradient done"); + flush(stdout) + + rel_err = maximum(abs.(grad_ad .- grad_fd) ./ (abs.(grad_fd) .+ 1e-8)) + @test isfinite(val_ad) + @test rel_err < 5e-5 + end + + @testset "Log prior/likelihood split and tempering" begin + println("[AutoMargTest] Log split: compiling model..."); + flush(stdout) + simple_def = @bugs begin + mu ~ Normal(0, 1) + z ~ Categorical(w[1:K]) + y ~ Normal(mu + delta[z], sigma) + end + + data = ( + K=2, + w=[0.3, 0.7], + delta=[0.0, 2.0], + sigma=1.0, + y=1.5, + ) + + model = compile(simple_def, data) + model = settrans(model, true) + model = set_evaluation_mode(model, UseAutoMarginalization()) + + θ = [0.0] # mu in transformed space (identity bijector) + _, stats = evaluate_with_marginalization_values!!(model, θ; temperature=0.4) + + expected_logprior = logpdf(Normal(0, 1), 0.0) + log_weighted = [ + log(data.w[i]) + logpdf(Normal(0.0 + data.delta[i], data.sigma), data.y) for + i in 1:data.K + ] + expected_loglik = LogExpFunctions.logsumexp(log_weighted) + + @test isapprox(stats.logprior, expected_logprior; atol=1e-10) + @test isapprox(stats.loglikelihood, expected_loglik; atol=1e-10) + 
@test isapprox( + stats.tempered_logjoint, expected_logprior + 0.4 * expected_loglik; atol=1e-10 + ) + + ad_model = ADgradient(AutoForwardDiff(), model) + val, grad = LogDensityProblems.logdensity_and_gradient(ad_model, θ) + @test isapprox(val, expected_logprior + expected_loglik; atol=1e-10) + function f_scalar(mu_val) + LogDensityProblems.logdensity(model, [mu_val]) + end + ϵ = 1e-6 + fd_grad = (f_scalar(θ[1] + ϵ) - f_scalar(θ[1] - ϵ)) / (2ϵ) + @test isapprox(grad[1], fd_grad; atol=1e-6) + end + + @testset "Efficiency smoke: AutoMarg+NUTS vs Graph+IndependentMH" begin + println("[AutoMargTest] Efficiency smoke: compiling models..."); + flush(stdout) + # Minimal smoke test to ensure both pipelines run (not a benchmark) + mixture_def = @bugs begin + w[1] = 0.3 + w[2] = 0.7 + mu[1] ~ Normal(-2, 5) + mu[2] ~ Normal(2, 5) + sigma[1] ~ Exponential(1) + sigma[2] ~ Exponential(1) + for i in 1:N + z[i] ~ Categorical(w[1:2]) + y[i] ~ Normal(mu[z[i]], sigma[z[i]]) + end + end + + data = (N=100, y=vcat(rand(Normal(-2, 1), 50), rand(Normal(2, 1), 50))) + + # Graph model with IndependentMH (quick smoke run) + model_graph = (m -> (m -> set_evaluation_mode(m, UseGraph()))(settrans(m, true)))( + compile(mixture_def, data) + ) + gibbs = JuliaBUGS.Gibbs(model_graph, JuliaBUGS.IndependentMH()) + println("[AutoMargTest] Efficiency smoke: sampling Graph+IMH..."); + flush(stdout) + chn_graph = AbstractMCMC.sample( + Random.default_rng(), + model_graph, + gibbs, + 10; + progress=false, + chain_type=MCMCChains.Chains, + ) + println("[AutoMargTest] Efficiency smoke: Graph+IMH done"); + flush(stdout) + @test length(chn_graph) == 10 + + # Auto-marginalized model with small-step NUTS + model_marg = ( + m -> (m -> set_evaluation_mode(m, UseAutoMarginalization()))(settrans(m, true)) + )( + compile(mixture_def, data) + ) + @test LogDensityProblems.dimension(model_marg) < + LogDensityProblems.dimension(model_graph) + # Run gradient-based sampling (NUTS) on the auto-marginalized AD-wrapped model + ad_model = ADgradient(AutoForwardDiff(), model_marg) + D = LogDensityProblems.dimension(model_marg) + θ0 = zeros(D) + println("[AutoMargTest] Efficiency smoke: sampling AutoMarg+NUTS..."); + flush(stdout) + samps = AbstractMCMC.sample( + Random.default_rng(), + ad_model, + NUTS(0.65), + 10; + progress=false, + n_adapts=0, + init_params=θ0, + discard_initial=0, + ) + println("[AutoMargTest] Efficiency smoke: AutoMarg+NUTS done"); + flush(stdout) + # Ensure sampling executed without errors + @test !isnothing(samps) + end +end diff --git a/JuliaBUGS/test/model/auto_marginalization_sampling.jl b/JuliaBUGS/test/model/auto_marginalization_sampling.jl new file mode 100644 index 000000000..00295ee4f --- /dev/null +++ b/JuliaBUGS/test/model/auto_marginalization_sampling.jl @@ -0,0 +1,95 @@ +using JuliaBUGS: @bugs, compile, settrans, getparams, initialize! +using JuliaBUGS.Model: + set_evaluation_mode, + UseAutoMarginalization, + parameters, + evaluate_with_marginalization_values!! + +@testset "Auto-Marginalization Sampling (NUTS)" begin + # 2-component GMM with fixed weights. Discrete z marginalized out. 
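+    # Because z is integrated out, the log-density is smooth in (mu, sigma),
+    # so NUTS runs directly on the continuous parameters; under UseGraph the
+    # unobserved z[i] would appear as extra sampled coordinates instead.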
+    mixture_def = @bugs begin
+        w[1] = 0.3
+        w[2] = 0.7
+
+        # Moderately informative priors to aid identifiability
+        mu[1] ~ Normal(-2, 1)
+        mu[2] ~ Normal(2, 1)
+        sigma[1] ~ Exponential(1)
+        sigma[2] ~ Exponential(1)
+
+        for i in 1:N
+            z[i] ~ Categorical(w[1:2])
+            y[i] ~ Normal(mu[z[i]], sigma[z[i]])
+        end
+    end
+
+    # Generate data from the ground-truth parameters
+    N = 120
+    true_w = [0.3, 0.7]
+    true_mu = [-2.0, 2.0]
+    true_sigma = [1.0, 1.0]
+    rng = StableRNG(1234)
+    # Partially observed assignments to break label switching and speed convergence
+    z_full = Vector{Int}(undef, N)
+    z_obs = Vector{Union{Int,Missing}}(undef, N)
+    # First 30 guaranteed component 1, last 30 guaranteed component 2
+    for i in 1:30
+        z_full[i] = 1
+        z_obs[i] = 1
+    end
+    for i in (N - 29):N
+        z_full[i] = 2
+        z_obs[i] = 2
+    end
+    # Middle indices drawn randomly
+    for i in 31:(N - 30)
+        z_full[i] = rand(rng, Categorical(true_w))
+        z_obs[i] = missing
+    end
+    # Generate y
+    y = Vector{Float64}(undef, N)
+    for i in 1:N
+        y[i] = rand(rng, Normal(true_mu[z_full[i]], true_sigma[z_full[i]]))
+    end
+
+    data = (N=N, y=y, z=z_obs)
+
+    # Compile the auto-marginalized model and wrap it with AD for NUTS
+    model = set_evaluation_mode(
+        settrans(compile(mixture_def, data), true), UseAutoMarginalization()
+    )
+    # Initialize near ground truth for faster convergence
+    initialize!(model, (; mu=[-2.0, 2.0], sigma=[1.0, 1.0]))
+    ad_model = ADgradient(AutoForwardDiff(), model)
+
+    # Initialize at current transformed parameters
+    θ0 = getparams(model)
+
+    # NUTS run with enough post-adaptation samples to recover the means reasonably well
+    n_samples, n_adapts = 2000, 1000
+    # Collect the draws as MCMCChains.Chains for easy summarization
+    chain = AbstractMCMC.sample(
+        rng,
+        ad_model,
+        NUTS(0.65),
+        n_samples;
+        progress=false,
+        chain_type=MCMCChains.Chains,
+        n_adapts=n_adapts,
+        init_params=θ0,
+        discard_initial=n_adapts,
+    )
+
+    # Estimate means directly from Chains
+    means = mean(chain)
+    mu1_hat = means[Symbol("mu[1]")].nt.mean[1]
+    mu2_hat = means[Symbol("mu[2]")].nt.mean[1]
+
+    # Unequal weights (0.3 vs 0.7) plus the partially observed assignments make
+    # label switching unlikely, so compare directly to ground truth with a
+    # generous absolute tolerance
+    @test isapprox(mu1_hat, true_mu[1]; atol=0.20)
+    @test isapprox(mu2_hat, true_mu[2]; atol=0.20)
+end
diff --git a/JuliaBUGS/test/model/frontier_cache_hmm.jl b/JuliaBUGS/test/model/frontier_cache_hmm.jl
new file mode 100644
index 000000000..e4d8da4da
--- /dev/null
+++ b/JuliaBUGS/test/model/frontier_cache_hmm.jl
@@ -0,0 +1,147 @@
+using Test
+using JuliaBUGS
+using JuliaBUGS: @bugs, compile, @varname
+using JuliaBUGS.Model:
+    _precompute_minimal_cache_keys, _marginalize_recursive, smart_copy_evaluation_env
+
+@testset "Frontier cache for HMM under different orders" begin
+    println("[FrontierCacheTest] Start HMM frontier cache tests...");
+    flush(stdout)
+    # Simple HMM with fixed emission parameters (no continuous params)
+    hmm_def = @bugs begin
+        mu[1] = 0.0
+        mu[2] = 5.0
+        sigma = 1.0
+
+        trans[1, 1] = 0.7
+        trans[1, 2] = 0.3
+        trans[2, 1] = 0.4
+        trans[2, 2] = 0.6
+
+        pi[1] = 0.5
+        pi[2] = 0.5
+
+        z[1] ~ Categorical(pi[1:2])
+        for t in 2:T
+            p[t, 1] = trans[z[t - 1], 1]
+            p[t, 2] = trans[z[t - 1], 2]
+            z[t] ~ Categorical(p[t, :])
+        end
+
+        for t in 1:T
+            y[t] ~ Normal(mu[z[t]], sigma)
+        end
+    end
+
+    T = 3
+    data = (T=T, y=[0.1, 4.9, 5.1])
+    model = compile(hmm_def, data)
+
+    gd = model.graph_evaluation_data
+    n = length(gd.sorted_nodes)
+
+    # Helper: index lookup for variables of interest
+    vn = Dict(
+        :z1 => @varname(z[1]),
+        :z2 => @varname(z[2]),
+        :z3 => @varname(z[3]),
+        :y1 => @varname(y[1]),
+        :y2 => @varname(y[2]),
+        :y3 => @varname(y[3]),
+    )
+    idx = Dict{Symbol,Int}()
+    for (k, v) in vn
+        i = findfirst(==(v), gd.sorted_nodes)
+        @test i !== nothing # ensure nodes exist
+        idx[k] = i
+    end
+
+    # Construct two evaluation orders as permutations of 1:n
+    # Interleaved: z1, y1, z2, y2, z3, y3, then the rest
+    priority_interleaved = [idx[:z1], idx[:y1], idx[:z2], idx[:y2], idx[:z3], idx[:y3]]
+    rest_interleaved = [i for i in 1:n if i ∉ priority_interleaved]
+    order_interleaved = vcat(priority_interleaved, rest_interleaved)
+
+    # States-first: z1, z2, z3, y1, y2, y3, then the rest
+    priority_states_first = [idx[:z1], idx[:z2], idx[:z3], idx[:y1], idx[:y2], idx[:y3]]
+    rest_states_first = [i for i in 1:n if i ∉ priority_states_first]
+    order_states_first = vcat(priority_states_first, rest_states_first)
+
+    # Precompute minimal keys for both orders
+    println("[FrontierCacheTest] Computing minimal keys (interleaved)...");
+    flush(stdout)
+    keys_interleaved = _precompute_minimal_cache_keys(model, order_interleaved)
+    println("[FrontierCacheTest] Computing minimal keys (states-first)...");
+    flush(stdout)
+    keys_states_first = _precompute_minimal_cache_keys(model, order_states_first)
+
+    # Helper to map frontier indices back to a set of variable symbols we care about
+    function frontier_syms(keys, key_idx)
+        frontier = get(keys, key_idx, Int[])
+        syms = Set{Symbol}()
+        for (name, i) in idx
+            if i in frontier
+                push!(syms, name)
+            end
+        end
+        return syms
+    end
+
+    # Interleaved expectations: frontier size stays 1; y[t] depends on z[t]
+    @test frontier_syms(keys_interleaved, idx[:z1]) == Set{Symbol}()
+    @test frontier_syms(keys_interleaved, idx[:y1]) == Set([:z1])
+    @test frontier_syms(keys_interleaved, idx[:z2]) == Set([:z1])
+    @test frontier_syms(keys_interleaved, idx[:y2]) == Set([:z2])
+    @test frontier_syms(keys_interleaved, idx[:z3]) == Set([:z2])
+    @test frontier_syms(keys_interleaved, idx[:y3]) == Set([:z3])
+
+    # States-first expectations: frontier grows across z's, peaks at y1
+    @test frontier_syms(keys_states_first, idx[:z1]) == Set{Symbol}()
+    @test frontier_syms(keys_states_first, idx[:z2]) == Set([:z1])
+    @test frontier_syms(keys_states_first, idx[:z3]) == Set([:z1, :z2])
+    @test frontier_syms(keys_states_first, idx[:y1]) == Set([:z1, :z2, :z3])
+    @test frontier_syms(keys_states_first, idx[:y2]) == Set([:z2, :z3])
+    @test frontier_syms(keys_states_first, idx[:y3]) == Set([:z3])
+
+    # Sanity: different orders should not change marginalized log-density
+    env = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols)
+    params = Float64[]
+    # The recursion takes parameter offsets/lengths, memoizes on
+    # (position, frontier indices, frontier values) keys, and returns a
+    # (logprior, loglikelihood) pair
+    param_offsets = Dict{VarName,Int}()
+    var_lengths = Dict{VarName,Int}()
+    memo1 = Dict{Tuple{Int,Tuple,Tuple},Any}()
+    println("[FrontierCacheTest] Evaluating logp with interleaved order...");
+    flush(stdout)
+    prior1, lik1 = _marginalize_recursive(
+        model,
+        env,
+        order_interleaved,
+        params,
+        param_offsets,
+        var_lengths,
+        memo1,
+        keys_interleaved,
+    )
+    logp1 = prior1 + lik1
+
+    env2 = smart_copy_evaluation_env(model.evaluation_env, model.mutable_symbols)
+    memo2 = Dict{Tuple{Int,Tuple,Tuple},Any}()
+    println("[FrontierCacheTest] Evaluating logp with states-first order...");
+    flush(stdout)
+    prior2, lik2 = _marginalize_recursive(
+        model,
+        env2,
+        order_states_first,
+        params,
+        param_offsets,
+        var_lengths,
+        memo2,
+        keys_states_first,
+    )
+    logp2 = prior2 + lik2
+    
println("[FrontierCacheTest] Done evaluations, comparing..."); + flush(stdout) + + @test isapprox(logp1, logp2; atol=1e-10) + + # And states-first should lead to equal or larger memo usage (worse frontier) + @test length(memo2) >= length(memo1) +end diff --git a/JuliaBUGS/test/runtests.jl b/JuliaBUGS/test/runtests.jl index c77a950e9..6d41660cb 100644 --- a/JuliaBUGS/test/runtests.jl +++ b/JuliaBUGS/test/runtests.jl @@ -27,6 +27,7 @@ using JuliaBUGS.BUGSPrimitives: mean using LinearAlgebra using LogDensityProblems using LogDensityProblemsAD +using LogExpFunctions using MacroTools using MetaGraphsNext using OrderedCollections @@ -40,7 +41,7 @@ using AdvancedMH using MCMCChains using ReverseDiff -JuliaBUGS.@bugs_primitive Beta Bernoulli Categorical Gamma InverseGamma Normal Uniform LogNormal Poisson +JuliaBUGS.@bugs_primitive Beta Bernoulli Categorical Exponential Gamma InverseGamma Normal Uniform LogNormal Poisson JuliaBUGS.@bugs_primitive Diagonal Dirichlet LKJ MvNormal JuliaBUGS.@bugs_primitive censored product_distribution truncated JuliaBUGS.@bugs_primitive fill ones zeros @@ -71,11 +72,14 @@ const TEST_GROUPS = OrderedDict{String,Function}( end, "log_density" => () -> begin include("model/evaluation.jl") + include("model/auto_marginalization.jl") + include("model/frontier_cache_hmm.jl") end, "inference" => () -> begin include("independent_mh.jl") include("ext/JuliaBUGSAdvancedHMCExt.jl") include("ext/JuliaBUGSMCMCChainsExt.jl") + include("model/auto_marginalization_sampling.jl") end, "inference_hmc" => () -> include("ext/JuliaBUGSAdvancedHMCExt.jl"), "inference_chains" => () -> include("ext/JuliaBUGSMCMCChainsExt.jl"),