Generating Julia function for log density evaluation (#278)
### Motivation
JuliaBUGS compiles BUGS programs into a directed probabilistic graphical
model (PGM), which implicitly defines the dependency structure between
variables. While this graph allows for execution strategies like
topological traversal and parallelization, a significant challenge
arises from the BUGS language semantics: every element within an array
can be treated as an individual random variable.
Because of this fine-grained dependency structure, naively generating Julia
source from the variable-level graph would often require fully unrolling
all loops. That approach is infeasible for large datasets or complex
models, and the resulting code is difficult for automatic differentiation
(AD) tools to analyze.
### Proposed Changes
This PR introduces an initial implementation that generates a
specialized Julia function for computing the log density of the
model. The core idea is to operate at a higher level of abstraction than
individual variable nodes.
The algorithm proceeds as follows:
1. **Statement-Level Dependence Graph:** Construct a dependence graph
where nodes represent the *statements* in the BUGS program, rather than
individual random variables. Edges represent dependencies between these
statements.
2. **Acyclicity Check:** Verify if this statement-level graph contains
any cycles (including self-loops).
3. **Topological Sort & Loop Fission:** If the graph is acyclic, perform
a topological sort on the statements. Based on this order, restructure
the program by applying full loop fission: loops over different variables
are separated so that computations occur in a valid dependency order.
(Steps 1–3 are sketched in code after this list.)
4. **Code Generation:** Generate a Julia function based on the
topologically sorted and fissioned sequence of statements. Specialized
code is generated for:
    * Deterministic assignments (`=`).
    * Stochastic assignments / Priors (`~`).
    * Observations (likelihood terms, `≂`).
The generated function takes a flattened vector of parameter values and
reconstructs them using `Bijectors.jl` to handle constraints and compute
log Jacobian adjustments, accumulating the log-prior and log-likelihood
terms.
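To make steps 1–3 concrete, here is a minimal sketch using Graphs.jl
(illustrative only, not necessarily JuliaBUGS's internal representation);
the node numbers match the statement numbers of the Rats model in the
next section.
```julia
using Graphs

# Statement-level dependence graph for the 11 statements of the Rats model.
# An edge u -> v means statement v reads a variable that statement u defines.
g = SimpleDiGraph(11)
for (u, v) in [(8, 3), (7, 3), (10, 4), (9, 4), (7, 11), (9, 11),
               (3, 2), (4, 2), (2, 1), (5, 1), (5, 6)]
    add_edge!(g, u, v)
end

# Step 2: acyclicity check; a cycle (including a self-loop) means the
# statement-level relaxation is too coarse and this code path cannot be used.
is_cyclic(g) && error("statement-level dependence graph contains a cycle")

# Step 3: a topological order of the statements, which fixes the order of
# the fissioned loops in the restructured program.
order = topological_sort_by_dfs(g)
```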
### Example: Rats Model
Consider the classic "Rats" example:
**Original BUGS Code:**
```julia
begin
    for i in 1:N
        for j in 1:T
            Y[i, j] ~ dnorm(mu[i, j], tau_c)              # (1) Likelihood
            mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar) # (2) Deterministic
        end
        alpha[i] ~ dnorm(alpha_c, alpha_tau)              # (3) Prior
        beta[i] ~ dnorm(beta_c, beta_tau)                 # (4) Prior
    end
    tau_c ~ dgamma(0.001, 0.001)                          # (5) Prior
    sigma = 1 / sqrt(tau_c)                               # (6) Deterministic
    alpha_c ~ dnorm(0.0, 1.0e-6)                          # (7) Prior
    alpha_tau ~ dgamma(0.001, 0.001)                      # (8) Prior
    beta_c ~ dnorm(0.0, 1.0e-6)                           # (9) Prior
    beta_tau ~ dgamma(0.001, 0.001)                       # (10) Prior
    alpha0 = alpha_c - xbar * beta_c                      # (11) Deterministic
end
```
**Statement Dependence Graph:**
```mermaid
flowchart TB
    8 --> 3
    7 --> 3
    10 --> 4
    9 --> 11
    9 --> 4
    7 --> 11
    3 --> 2
    4 --> 2
    2 --> 1
    5 --> 1
    5 --> 6
```
*(Note: Mermaid graph slightly adjusted for clarity based on variable
dependencies)*
**Sequential Representation (after Topological Sort & Loop Fission):**
This intermediate representation reflects the order determined by the
statement graph dependencies and separates the original nested loops.
```julia
begin
    # Independent Priors first
    tau_c ~ dgamma(0.001, 0.001)         # (5)
    alpha_c ~ dnorm(0.0, 1.0e-6)         # (7)
    alpha_tau ~ dgamma(0.001, 0.001)     # (8)
    beta_c ~ dnorm(0.0, 1.0e-6)          # (9)
    beta_tau ~ dgamma(0.001, 0.001)      # (10)
    # Deterministic nodes depending only on above priors
    sigma = 1 / sqrt(tau_c)              # (6) (Depends on 5)
    alpha0 = alpha_c - xbar * beta_c     # (11) (Depends on 7, 9)
    # Priors depending on hyperparameters (loop fissioned)
    for i in 1:N
        alpha[i] ~ dnorm(alpha_c, alpha_tau)  # (3) (Depends on 7, 8)
    end
    for i in 1:N
        beta[i] ~ dnorm(beta_c, beta_tau)     # (4) (Depends on 9, 10)
    end
    # Deterministic node depending on loop variables (loop fissioned)
    for i in 1:N
        for j in 1:T
            mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar)  # (2) (Depends on 3, 4)
        end
    end
    # Likelihood (loop fissioned)
    for i in 1:N
        for j in 1:T
            Y[i, j] ≂ dnorm(mu[i, j], tau_c)  # (1) (Depends on 2, 5)
        end
    end
end
```
**Generated Julia Log Density Function:**
This function takes the model environment (`__evaluation_env__`) and
flattened parameters (`__flattened_values__`), computes the log density
(`__logp__`), and handles necessary transformations via `Bijectors.jl`.
```julia
function var"##__compute_log_density__#236"(__evaluation_env__, __flattened_values__)
    # Deconstruct the evaluation environment for easier access to variables
    (; alpha, var"beta.c", xbar, sigma, alpha0, x, N, var"alpha.c", mu, var"alpha.tau", Y, T, beta, var"beta.tau", var"tau.c") = __evaluation_env__
    # Initialize log density accumulator and flattened values index
    __logp__ = 0.0
    __current_idx__ = 1
    # --- Process Prior Distributions (in topological order) ---
    # Prior for beta_tau ~ dgamma(0.001, 0.001) [Statement 10]
    # (Gamma distribution requires positive values -> needs transformation)
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__) # Typically a log bijector for Gamma
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__)) # Usually 1 for scalar
    # Read value from flattened vector, reconstruct in transformed space
    __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
    __current_idx__ += __transformed_length__
    # Transform back to original space and get log Jacobian determinant
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    # Calculate log prior: logpdf in original space + log Jacobian adjustment
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    # Accumulate log density
    __logp__ = __logp__ + __logprior__
    # Assign the calculated value (in original space) to the variable
    var"beta.tau" = __value__ # Note: Using var"..." syntax for variables with dots
    # Prior for beta_c ~ dnorm(0.0, 1.0e-6) [Statement 9]
    # (Normal distribution is on R -> identity transform usually)
    __dist__ = dnorm(0.0, 1.0e-6)
    __b__ = Bijectors.bijector(__dist__) # Typically identity
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = if __b__ === identity
        length(__dist__)
    else
        length(Bijectors.transformed(__dist__, __b__))
    end
    __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__) # logjac is 0 for identity
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"beta.c" = __value__
    # Prior for alpha_tau ~ dgamma(0.001, 0.001) [Statement 8]
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__))
    __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"alpha.tau" = __value__
    # Prior for alpha_c ~ dnorm(0.0, 1.0e-6) [Statement 7]
    __dist__ = dnorm(0.0, 1.0e-6)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = if __b__ === identity
        length(__dist__)
    else
        length(Bijectors.transformed(__dist__, __b__))
    end
    __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"alpha.c" = __value__
    # --- Process Deterministic Calculations (in topological order) ---
    # alpha0 = alpha_c - xbar * beta_c [Statement 11]
    # Depends on alpha_c [7] and beta_c [9], which are now available
    alpha0 = var"alpha.c" - xbar * var"beta.c"
    # --- Process Remaining Priors & Deterministic Nodes ---
    # Prior for tau_c ~ dgamma(0.001, 0.001) [Statement 5]
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__))
    __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"tau.c" = __value__
    # sigma = 1 / sqrt(tau_c) [Statement 6]
    # Depends on tau_c [5], which is now available
    sigma = 1 / sqrt(var"tau.c")
    # --- Process Looped Priors (Fissioned Loops) ---
    # Prior for beta[i] ~ dnorm(beta_c, beta_tau) [Statement 4]
    # Depends on beta_c [9] and beta_tau [10]
    for i = 1:N
        __dist__ = dnorm(var"beta.c", var"beta.tau") # Parameters are known values now
        __b__ = Bijectors.bijector(__dist__)
        __b_inv__ = Bijectors.inverse(__b__)
        __transformed_length__ = if __b__ === identity
            length(__dist__)
        else
            length(Bijectors.transformed(__dist__, __b__))
        end
        # Read value for beta[i]
        __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
        __current_idx__ += __transformed_length__
        (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
        __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
        __logp__ = __logp__ + __logprior__
        # Assign to the specific element beta[i]
        beta[Int(i)] = __value__ # Using Int(i) defensively for indexing
    end
    # Prior for alpha[i] ~ dnorm(alpha_c, alpha_tau) [Statement 3]
    # Depends on alpha_c [7] and alpha_tau [8]
    for i = 1:N
        __dist__ = dnorm(var"alpha.c", var"alpha.tau")
        __b__ = Bijectors.bijector(__dist__)
        __b_inv__ = Bijectors.inverse(__b__)
        __transformed_length__ = if __b__ === identity
            length(__dist__)
        else
            length(Bijectors.transformed(__dist__, __b__))
        end
        # Read value for alpha[i]
        __reconstructed_value__ = JuliaBUGS.reconstruct(__b_inv__, __dist__, view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1))
        __current_idx__ += __transformed_length__
        (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
        __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
        __logp__ = __logp__ + __logprior__
        # Assign to the specific element alpha[i]
        alpha[Int(i)] = __value__
    end
    # --- Process Looped Deterministic Calculations (Fissioned Loops) ---
    # mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar) [Statement 2]
    # Depends on alpha[i] [3] and beta[i] [4]
    for i = 1:N
        for j = 1:T
            # Both alpha[i] and beta[i] are available from previous loops
            # x[j] and xbar are from the environment (constants/data)
            mu[Int(i), Int(j)] = alpha[Int(i)] + beta[Int(i)] * (x[Int(j)] - xbar)
        end
    end
    # --- Process Likelihood Contribution (Fissioned Loop) ---
    # Y[i, j] ~ dnorm(mu[i, j], tau_c) [Statement 1]
    # Depends on mu[i, j] [2] and tau_c [5]
    for i = 1:N
        for j = 1:T
            # mu[i,j] and tau_c are available
            # Y[i,j] is data from the environment
            # Add log-likelihood contribution directly (no Jacobian for data)
            __logp__ += logpdf(dnorm(mu[Int(i), Int(j)], var"tau.c"), Y[Int(i), Int(j)])
        end
    end
    # --- Final Assertion and Return ---
    # Sanity check: ensure all values from the flattened vector have been consumed
    @assert __current_idx__ == length(__flattened_values__) + 1 "Indexing error: Not all parameter values were used or too many were expected."
    # Return the total computed log density
    return __logp__
end
```
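Before the numbers, a hedged usage sketch: the field names below match
those used in the benchmarks in the next section, but the exact `compile`
signature may differ between JuliaBUGS versions.
```julia
using JuliaBUGS, LogDensityProblems

# `model_def` and `data` are assumed to hold the Rats program and data from above.
model = JuliaBUGS.compile(model_def, data)

# A random point in the unconstrained (transformed) parameter space.
D = LogDensityProblems.dimension(model)
x = randn(D)

# Call the generated function directly with the evaluation environment.
logp = model.log_density_computation_function(model.evaluation_env, x)
```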
### Performance
The generated function is roughly 45× faster at the median (≈874 ns vs
≈39.5 μs) and allocation-free, compared to evaluating the log density
through the generic `LogDensityProblems.logdensity` interface, which
carries more overhead:
**`LogDensityProblems.logdensity` Benchmark:**
```julia
julia> @benchmark LogDensityProblems.logdensity($model, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
Range (min … max): 33.708 μs … 9.379 ms ┊ GC (min … max): 0.00% … 99.21%
Time (median): 39.459 μs ┊ GC (median): 0.00%
Time (mean ± σ): 41.575 μs ± 122.243 μs ┊ GC (mean ± σ): 4.88% ± 1.71%
# ... histogram ...
Memory estimate: 60.53 KiB, allocs estimate: 1566.
```
**Directly Generated Function Benchmark:**
```julia
julia> @benchmark $(model.log_density_computation_function)($(model.evaluation_env), $x)
BenchmarkTools.Trial: 10000 samples with 103 evaluations per sample.
Range (min … max): 781.146 ns … 1.123 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 873.786 ns ┊ GC (median): 0.00%
Time (mean ± σ): 873.010 ns ± 28.609 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
# ... histogram ...
Memory estimate: 0 bytes, allocs estimate: 0.
```
**Gradient Performance (using Mooncake AD on the generated function):**
The generated function structure is also amenable to AD, yielding
efficient gradient computations:
```julia
BenchmarkTools.Trial: 10000 samples with 6 evaluations per sample.
Range (min … max): 5.160 μs … 18.174 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 5.938 μs ┊ GC (median): 0.00%
Time (mean ± σ): 5.958 μs ± 389.948 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
# ... histogram ...
Memory estimate: 1.19 KiB, allocs estimate: 6.
```
which is on par with Stan.
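For reference, a hedged sketch of how such a gradient computation might be
set up with Mooncake; `build_rrule` and `value_and_gradient!!` are
Mooncake's public entry points, but exact names and signatures can vary
across versions.
```julia
using Mooncake

# Assumes `model` and `x` from the usage sketch above.
fn  = model.log_density_computation_function
env = model.evaluation_env

# Build a reverse-mode rule for this call signature, then evaluate the log
# density together with gradients w.r.t. the function and both arguments;
# the gradient of interest is the one w.r.t. the flattened vector `x`.
rule = Mooncake.build_rrule(fn, env, x)
logp, (d_fn, d_env, d_x) = Mooncake.value_and_gradient!!(rule, fn, env, x)
```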
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Penelope Yong <[email protected]>