
Commit ae083dc

Authored by sunxd3, github-actions[bot], and penelopeysm
Generating Julia function for log density evaluation (#278)
### Motivation

JuliaBUGS compiles BUGS programs into a directed probabilistic graphical model (PGM), which implicitly defines the dependency structure between variables. While this graph allows for execution strategies like topological traversal and parallelization, a significant challenge arises from the BUGS language semantics: every element within an array can be treated as an individual random variable. Because of this fine-grained dependency structure, naively generating Julia source from the variable-level graph would often require fully unrolling all loops. That approach is infeasible, especially for large datasets or complex models, and makes the program difficult for automatic differentiation (AD) tools to analyze.

### Proposed Changes

This PR introduces an initial implementation for generating a specialized Julia function dedicated to computing the log density of the model. The core idea is to operate at a higher level of abstraction than individual variable nodes. The algorithm proceeds as follows:

1. **Statement-Level Dependence Graph:** Construct a dependence graph whose nodes represent the *statements* in the BUGS program, rather than individual random variables. Edges represent dependencies between these statements.
2. **Acyclicity Check:** Verify whether this statement-level graph contains any cycles (including self-loops).
3. **Topological Sort & Loop Fission:** If the graph is acyclic, perform a topological sort on the statements. Based on this order, restructure the program, applying full loop fission. This separates loops operating on different variables, ensuring that computations occur in a valid dependency order.
4. **Code Generation:** Generate a Julia function based on the topologically sorted and fissioned sequence of statements. Specialized code is generated for:
   * Deterministic assignments (`=`).
   * Stochastic assignments / priors (`~`).
   * Observations (likelihood terms, marked `≂` in the intermediate representation).
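To make steps 1–3 concrete, here is a minimal, language-agnostic sketch in Python (names are hypothetical; the actual implementation operates on JuliaBUGS's internal statement representation) of the acyclicity check and topological ordering via Kahn's algorithm:

```python
from collections import defaultdict, deque

def topological_order(num_statements, edges):
    """Kahn's algorithm over a statement dependence graph.

    `edges` is a list of (src, dst) pairs, meaning statement `dst`
    depends on statement `src`. Returns a valid evaluation order,
    or None if the graph contains a cycle (including self-loops),
    mirroring the acyclicity check in step 2.
    """
    successors = defaultdict(list)
    indegree = [0] * num_statements
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1

    # Start from statements with no unmet dependencies
    queue = deque(i for i in range(num_statements) if indegree[i] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for nxt in successors[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                queue.append(nxt)

    # If not every statement was emitted, a cycle blocked progress
    return order if len(order) == num_statements else None
```

For the Rats example, any order this returns places statement (1), the likelihood, after its dependencies (2) and (5), which is exactly the property the fissioned program relies on.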
The generated function takes a flattened vector of parameter values and reconstructs them using `Bijectors.jl` to handle constraints and compute log Jacobian adjustments, accumulating the log-prior and log-likelihood terms.

### Example: Rats Model

Consider the classic "Rats" example.

**Original BUGS Code:**

```julia
begin
    for i in 1:N
        for j in 1:T
            Y[i, j] ~ dnorm(mu[i, j], tau_c)                # (1) Likelihood
            mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar)   # (2) Deterministic
        end
        alpha[i] ~ dnorm(alpha_c, alpha_tau)                # (3) Prior
        beta[i] ~ dnorm(beta_c, beta_tau)                   # (4) Prior
    end
    tau_c ~ dgamma(0.001, 0.001)                            # (5) Prior
    sigma = 1 / sqrt(tau_c)                                 # (6) Deterministic
    alpha_c ~ dnorm(0.0, 1.0e-6)                            # (7) Prior
    alpha_tau ~ dgamma(0.001, 0.001)                        # (8) Prior
    beta_c ~ dnorm(0.0, 1.0e-6)                             # (9) Prior
    beta_tau ~ dgamma(0.001, 0.001)                         # (10) Prior
    alpha0 = alpha_c - xbar * beta_c                        # (11) Deterministic
end
```

**Statement Dependence Graph:**

```mermaid
flowchart TB
    8 --> 3
    7 --> 3
    10 --> 4
    9 --> 11
    9 --> 4
    4 --> 11
    3 --> 2
    4 --> 2
    2 --> 1
    5 --> 1
    5 --> 6
```

*(Note: Mermaid graph slightly adjusted for clarity based on variable dependencies.)*

**Sequential Representation (after Topological Sort & Loop Fission):**

This intermediate representation reflects the order determined by the statement graph dependencies and separates the original nested loops.
```julia
begin
    # Independent priors first
    tau_c ~ dgamma(0.001, 0.001)        # (5)
    alpha_c ~ dnorm(0.0, 1.0e-6)        # (7)
    alpha_tau ~ dgamma(0.001, 0.001)    # (8)
    beta_c ~ dnorm(0.0, 1.0e-6)         # (9)
    beta_tau ~ dgamma(0.001, 0.001)     # (10)

    # Deterministic nodes depending only on the priors above
    sigma = 1 / sqrt(tau_c)             # (6) (depends on 5)
    alpha0 = alpha_c - xbar * beta_c    # (11) (depends on 7, 9)

    # Priors depending on hyperparameters (loop fissioned)
    for i in 1:N
        alpha[i] ~ dnorm(alpha_c, alpha_tau)    # (3) (depends on 7, 8)
    end
    for i in 1:N
        beta[i] ~ dnorm(beta_c, beta_tau)       # (4) (depends on 9, 10)
    end

    # Deterministic node depending on loop variables (loop fissioned)
    for i in 1:N
        for j in 1:T
            mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar)   # (2) (depends on 3, 4)
        end
    end

    # Likelihood (loop fissioned)
    for i in 1:N
        for j in 1:T
            Y[i, j] ≂ dnorm(mu[i, j], tau_c)    # (1) (depends on 2, 5)
        end
    end
end
```

**Generated Julia Log Density Function:**

This function takes the model environment (`__evaluation_env__`) and flattened parameters (`__flattened_values__`), computes the log density (`__logp__`), and handles the necessary transformations via `Bijectors.jl`.
```julia
function var"##__compute_log_density__#236"(__evaluation_env__, __flattened_values__)
    # Deconstruct the evaluation environment for easier access to variables
    (; alpha, var"beta.c", xbar, sigma, alpha0, x, N, var"alpha.c", mu,
        var"alpha.tau", Y, T, beta, var"beta.tau", var"tau.c") = __evaluation_env__

    # Initialize log density accumulator and flattened values index
    __logp__ = 0.0
    __current_idx__ = 1

    # --- Process prior distributions (in topological order) ---

    # Prior for beta_tau ~ dgamma(0.001, 0.001) [Statement 10]
    # (Gamma distribution requires positive values -> needs transformation)
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__)    # typically a log bijector for Gamma
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__))  # usually 1 for scalars
    # Read value from flattened vector, reconstruct in transformed space
    __reconstructed_value__ = JuliaBUGS.reconstruct(
        __b_inv__,
        __dist__,
        view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
    )
    __current_idx__ += __transformed_length__
    # Transform back to original space and get the log Jacobian determinant
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    # Calculate log prior: logpdf in original space + log Jacobian adjustment
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    # Accumulate log density
    __logp__ = __logp__ + __logprior__
    # Assign the calculated value (in original space) to the variable
    var"beta.tau" = __value__   # note: var"..." syntax for variables with dots

    # Prior for beta_c ~ dnorm(0.0, 1.0e-6) [Statement 9]
    # (Normal distribution is on R -> identity transform usually)
    __dist__ = dnorm(0.0, 1.0e-6)
    __b__ = Bijectors.bijector(__dist__)    # typically identity
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = if __b__ === identity
        length(__dist__)
    else
        length(Bijectors.transformed(__dist__, __b__))
    end
    __reconstructed_value__ = JuliaBUGS.reconstruct(
        __b_inv__,
        __dist__,
        view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
    )
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)  # logjac is 0 for identity
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"beta.c" = __value__

    # Prior for alpha_tau ~ dgamma(0.001, 0.001) [Statement 8]
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__))
    __reconstructed_value__ = JuliaBUGS.reconstruct(
        __b_inv__,
        __dist__,
        view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
    )
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"alpha.tau" = __value__

    # Prior for alpha_c ~ dnorm(0.0, 1.0e-6) [Statement 7]
    __dist__ = dnorm(0.0, 1.0e-6)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = if __b__ === identity
        length(__dist__)
    else
        length(Bijectors.transformed(__dist__, __b__))
    end
    __reconstructed_value__ = JuliaBUGS.reconstruct(
        __b_inv__,
        __dist__,
        view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
    )
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"alpha.c" = __value__

    # --- Process deterministic calculations (in topological order) ---

    # alpha0 = alpha_c - xbar * beta_c [Statement 11]
    # Depends on alpha_c [7] and beta_c [9], which are now available
    alpha0 = var"alpha.c" - xbar * var"beta.c"

    # --- Process remaining priors & deterministic nodes ---

    # Prior for tau_c ~ dgamma(0.001, 0.001) [Statement 5]
    __dist__ = dgamma(0.001, 0.001)
    __b__ = Bijectors.bijector(__dist__)
    __b_inv__ = Bijectors.inverse(__b__)
    __transformed_length__ = length(Bijectors.transformed(__dist__, __b__))
    __reconstructed_value__ = JuliaBUGS.reconstruct(
        __b_inv__,
        __dist__,
        view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
    )
    __current_idx__ += __transformed_length__
    (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
    __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
    __logp__ = __logp__ + __logprior__
    var"tau.c" = __value__

    # sigma = 1 / sqrt(tau_c) [Statement 6]
    # Depends on tau_c [5], which is now available
    sigma = 1 / sqrt(var"tau.c")

    # --- Process looped priors (fissioned loops) ---

    # Prior for beta[i] ~ dnorm(beta_c, beta_tau) [Statement 4]
    # Depends on beta_c [9] and beta_tau [10]
    for i = 1:N
        __dist__ = dnorm(var"beta.c", var"beta.tau")    # parameters are known values now
        __b__ = Bijectors.bijector(__dist__)
        __b_inv__ = Bijectors.inverse(__b__)
        __transformed_length__ = if __b__ === identity
            length(__dist__)
        else
            length(Bijectors.transformed(__dist__, __b__))
        end
        # Read value for beta[i]
        __reconstructed_value__ = JuliaBUGS.reconstruct(
            __b_inv__,
            __dist__,
            view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
        )
        __current_idx__ += __transformed_length__
        (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
        __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
        __logp__ = __logp__ + __logprior__
        # Assign to the specific element beta[i]
        beta[Int(i)] = __value__    # using Int(i) defensively for indexing
    end

    # Prior for alpha[i] ~ dnorm(alpha_c, alpha_tau) [Statement 3]
    # Depends on alpha_c [7] and alpha_tau [8]
    for i = 1:N
        __dist__ = dnorm(var"alpha.c", var"alpha.tau")
        __b__ = Bijectors.bijector(__dist__)
        __b_inv__ = Bijectors.inverse(__b__)
        __transformed_length__ = if __b__ === identity
            length(__dist__)
        else
            length(Bijectors.transformed(__dist__, __b__))
        end
        # Read value for alpha[i]
        __reconstructed_value__ = JuliaBUGS.reconstruct(
            __b_inv__,
            __dist__,
            view(__flattened_values__, __current_idx__:(__current_idx__ + __transformed_length__) - 1),
        )
        __current_idx__ += __transformed_length__
        (__value__, __logjac__) = Bijectors.with_logabsdet_jacobian(__b_inv__, __reconstructed_value__)
        __logprior__ = Distributions.logpdf(__dist__, __value__) + __logjac__
        __logp__ = __logp__ + __logprior__
        # Assign to the specific element alpha[i]
        alpha[Int(i)] = __value__
    end

    # --- Process looped deterministic calculations (fissioned loops) ---

    # mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar) [Statement 2]
    # Depends on alpha[i] [3] and beta[i] [4]
    for i = 1:N
        for j = 1:T
            # Both alpha[i] and beta[i] are available from the previous loops;
            # x[j] and xbar are from the environment (constants/data)
            mu[Int(i), Int(j)] = alpha[Int(i)] + beta[Int(i)] * (x[Int(j)] - xbar)
        end
    end

    # --- Process likelihood contribution (fissioned loop) ---

    # Y[i, j] ~ dnorm(mu[i, j], tau_c) [Statement 1]
    # Depends on mu[i, j] [2] and tau_c [5]
    for i = 1:N
        for j = 1:T
            # mu[i, j] and tau_c are available; Y[i, j] is data from the environment.
            # Add the log-likelihood contribution directly (no Jacobian for data)
            __logp__ += logpdf(dnorm(mu[Int(i), Int(j)], var"tau.c"), Y[Int(i), Int(j)])
        end
    end

    # --- Final assertion and return ---

    # Sanity check: ensure all values from the flattened vector have been consumed
    @assert __current_idx__ == length(__flattened_values__) + 1 "Indexing error: not all parameter values were used, or too many were expected."

    # Return the total computed log density
    return __logp__
end
```

### Performance

The generated function delivers a significant speedup and eliminates allocations compared to evaluating the log density through the generic `LogDensityProblems.logdensity` interface, which involves more overhead.

**`LogDensityProblems.logdensity` Benchmark:**

```julia
julia> @benchmark LogDensityProblems.logdensity($model, $x)
BenchmarkTools.Trial: 10000 samples with 1 evaluation per sample.
 Range (min … max):  33.708 μs …   9.379 ms  ┊ GC (min … max): 0.00% … 99.21%
 Time  (median):     39.459 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   41.575 μs ± 122.243 μs  ┊ GC (mean ± σ):  4.88% ±  1.71%
 # ... histogram ...
 Memory estimate: 60.53 KiB, allocs estimate: 1566.
```

**Directly Generated Function Benchmark:**

```julia
julia> @benchmark $(model.log_density_computation_function)($(model.evaluation_env), $x)
BenchmarkTools.Trial: 10000 samples with 103 evaluations per sample.
 Range (min … max):  781.146 ns …  1.123 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     873.786 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   873.010 ns ± 28.609 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%
 # ... histogram ...
 Memory estimate: 0 bytes, allocs estimate: 0.
```

**Gradient Performance (using Mooncake AD on the generated function):**

The generated function's structure is also amenable to AD, yielding efficient gradient computations:

```julia
BenchmarkTools.Trial: 10000 samples with 6 evaluations per sample.
 Range (min … max):  5.160 μs … 18.174 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     5.938 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.958 μs ± 389.948 ns ┊ GC (mean ± σ):  0.00% ± 0.00%
 # ... histogram ...
 Memory estimate: 1.19 KiB, allocs estimate: 6.
```

This is on par with Stan.
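The reconstruct-transform-accumulate pattern that the generated code repeats for each constrained prior can be distilled into a few lines. Below is a minimal Python sketch (function names hypothetical, not part of JuliaBUGS) for a single Gamma(shape, rate) prior: the unconstrained parameter `y` is mapped to the positive reals with `exp` (the inverse of the log bijector), and the log-Jacobian of that map, `log|d exp(y)/dy| = y`, is added alongside the logpdf in the original space:

```python
import math

def gamma_logpdf(x, shape, rate):
    # Log density of Gamma(shape, rate) at x > 0 (shape/rate parameterization,
    # matching BUGS's dgamma)
    return (shape * math.log(rate) - math.lgamma(shape)
            + (shape - 1) * math.log(x) - rate * x)

def accumulate_gamma_prior(flattened_values, idx, shape, rate):
    """Mimics one prior block of the generated function for a Gamma prior.

    Returns (value in original space, log-prior contribution, next index).
    """
    y = flattened_values[idx]        # read from the flattened vector
    value = math.exp(y)              # transform back to the constrained space
    logjac = y                       # log|det J| of the inverse transform
    logprior = gamma_logpdf(value, shape, rate) + logjac
    return value, logprior, idx + 1  # advance the index, as __current_idx__ does
```

For instance, at `y = 0.0` with a Gamma(2, 1) prior, the value is `exp(0) = 1`, the Jacobian term vanishes, and the contribution reduces to the plain logpdf at 1.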
---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Penelope Yong <[email protected]>
1 parent d5eeb6d commit ae083dc

35 files changed: +5381 −491 lines changed

.github/workflows/Tests.yml

Lines changed: 16 additions & 6 deletions
```diff
@@ -12,15 +12,20 @@ jobs:
   test:
     name: Julia ${{ matrix.version }} on ${{ matrix.os }} (${{ matrix.arch }})
     runs-on: ${{ matrix.os }}
-    continue-on-error: ${{ matrix.version == 'nightly' }}
+    continue-on-error: ${{ matrix.version == 'pre' }}
     strategy:
-      fail-fast: false
       matrix:
-        version: ['1', '1.10', 'nightly']
-        os: [ubuntu-latest, windows-latest]
-        arch: [x64]
+        version:
+          - '1'
+          - 'min'
+          - 'pre'
+        os:
+          - ubuntu-latest
+          - windows-latest
+        arch:
+          - x64
       include:
-        - version: '1'
+        - version: 'min'
           os: ubuntu-latest
           arch: x64
           coverage: true
@@ -60,6 +65,11 @@ jobs:
         env:
           TEST_GROUP: "log_density"

+      - name: Running `source_gen` tests
+        uses: julia-actions/julia-runtest@v1
+        env:
+          TEST_GROUP: "source_gen"
+
       - name: Running `gibbs` tests
         uses: nick-fields/retry@v3
         with:
```

Project.toml

Lines changed: 19 additions & 6 deletions
```diff
@@ -1,6 +1,6 @@
 name = "JuliaBUGS"
 uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
-version = "0.9"
+version = "0.9.1"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -39,11 +39,10 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 JuliaBUGSAdvancedHMCExt = ["AdvancedHMC", "MCMCChains"]
 JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
 JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
-JuliaBUGSGraphPlotExt = ["GraphPlot"]
-JuliaBUGSMCMCChainsExt = ["MCMCChains"]
+JuliaBUGSGraphPlotExt = "GraphPlot"
+JuliaBUGSMCMCChainsExt = "MCMCChains"

 [compat]
-ADTypes = "1.6"
 AbstractMCMC = "5"
 AbstractPPL = "0.8.4, 0.9, 0.10, 0.11"
 Accessors = "0.1"
@@ -52,6 +51,7 @@ AdvancedMH = "0.8"
 BangBang = "0.4.1"
 Bijectors = "0.13, 0.14, 0.15.5"
 ChainRules = "1"
+DifferentiationInterface = "0.6.42"
 Distributions = "0.23.8, 0.24, 0.25"
 Documenter = "0.27, 1"
 GLMakie = "0.10, 0.11"
@@ -67,6 +67,7 @@ LogExpFunctions = "0.3"
 MCMCChains = "6"
 MacroTools = "0.5"
 MetaGraphsNext = "0.6, 0.7"
+Mooncake = "0.4"
 OrderedCollections = "1"
 PDMats = "0.10, 0.11"
 Serialization = "1.10"
@@ -76,15 +77,27 @@ Statistics = "1.10"
 julia = "1.10.8"

 [extras]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["AbstractMCMC", "ADTypes", "AdvancedHMC", "AdvancedMH", "ChainRules", "MCMCChains", "LogDensityProblemsAD", "ReverseDiff", "Test"]
+test = [
+    "AbstractMCMC",
+    "AdvancedHMC",
+    "AdvancedMH",
+    "ChainRules",
+    "DifferentiationInterface",
+    "LogDensityProblemsAD",
+    "MCMCChains",
+    "Mooncake",
+    "ReverseDiff",
+    "Test"
+]
```

benchmark/Project.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,12 +1,13 @@
 [deps]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
 Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
```

benchmark/benchmark.jl

Lines changed: 4 additions & 6 deletions
```diff
@@ -1,11 +1,11 @@
-module Benchmark
-
 using Pkg
 Pkg.develop(; path=joinpath(@__DIR__, ".."))

 using JuliaBUGS
-using ADTypes
-using ReverseDiff
+
+using DifferentiationInterface
+using Mooncake: Mooncake
+
 using MetaGraphsNext
 using BridgeStan
 using StanLogDensityProblems
@@ -96,5 +96,3 @@ function _print_results_table(
         backend=backend,
     )
 end
-
-end
```

benchmark/juliabugs.jl

Lines changed: 21 additions & 7 deletions
```diff
@@ -71,15 +71,29 @@ function _create_JuliaBUGS_model(model_name::Symbol)
     return compile(model_def, data, inits)
 end

-function benchmark_JuliaBUGS_model(model::JuliaBUGS.BUGSModel)
-    ad_model = ADgradient(AutoReverseDiff(true), model)
+# ! writing a _function_ to benchmark all models won't work because of world-age error
+
+function benchmark_JuliaBUGS_model_with_Mooncake(model::JuliaBUGS.BUGSModel)
+    # p = Base.Fix1(LogDensityProblems.logdensity, model)
+    p = Base.Fix1(model.log_density_computation_function, model.evaluation_env)
+    backend = AutoMooncake(; config=nothing)
     dim = LogDensityProblems.dimension(model)
     params_values = JuliaBUGS.getparams(model)
-    density_time = Chairmarks.@be LogDensityProblems.logdensity($ad_model, $params_values)
-    density_and_gradient_time = Chairmarks.@be LogDensityProblems.logdensity_and_gradient(
-        $ad_model, $params_values
-    )
+    prep = prepare_gradient(p, backend, params_values)
+    density_time = Chairmarks.@be $p($params_values)
+    density_and_gradient_time = Chairmarks.@be gradient($p, $prep, $backend, $params_values)
     return BenchmarkResult(:juliabugs, dim, density_time, density_and_gradient_time)
 end

-# writing a _function_ to benchmark all models won't work because of worldage error
+# function benchmark_JuliaBUGS_model_with_Enzyme(model::JuliaBUGS.BUGSModel)
+#     f(params, model) = LogDensityProblems.logdensity(model, params)
+#     backend = AutoEnzyme()
+#     dim = LogDensityProblems.dimension(model)
+#     params_values = JuliaBUGS.getparams(model)
+#     prep = prepare_gradient(f, backend, params_values, Constant(model))
+#     density_time = Chairmarks.@be LogDensityProblems.logdensity($model, $params_values)
+#     density_and_gradient_time = Chairmarks.@be gradient(
+#         $f, $prep, $backend, $params_values, $(Constant(model))
+#     )
+#     return BenchmarkResult(:juliabugs_enzyme, dim, density_time, density_and_gradient_time)
+# end
```

benchmark/run_benchmarks.jl

Lines changed: 22 additions & 9 deletions
```diff
@@ -1,22 +1,35 @@
 include("benchmark.jl")
-using OrderedCollections

 examples_to_benchmark = [
     :rats, :pumps, :bones, :oxford, :epil, :lsat, :schools, :beetles, :air
 ]

-stan_results = Benchmark.benchmark_Stan_models(examples_to_benchmark)
+stan_results = benchmark_Stan_models(examples_to_benchmark)

 juliabugs_models = [
-    Benchmark._create_JuliaBUGS_model(model_name) for model_name in examples_to_benchmark
+    JuliaBUGS.set_evaluation_mode(
+        _create_JuliaBUGS_model(model_name), JuliaBUGS.UseGeneratedLogDensityFunction()
+    ) for model_name in examples_to_benchmark
 ]
-juliabugs_results = OrderedDict{Symbol,Benchmark.BenchmarkResult}()
+juliabugs_results = OrderedDict{Symbol,BenchmarkResult}()
 for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
-    @info "Benchmarking $model_name"
-    juliabugs_results[model_name] = Benchmark.benchmark_JuliaBUGS_model(model)
+    @info "Benchmarking $model_name with Mooncake"
+    juliabugs_results[model_name] = benchmark_JuliaBUGS_model_with_Mooncake(model)
 end

+# juliabugs_enzyme_results = OrderedDict{Symbol,BenchmarkResult}()
+# for (model_name, model) in zip(examples_to_benchmark, juliabugs_models)
+#     @info "Benchmarking $model_name with Enzyme"
+#     try
+#         juliabugs_enzyme_results[model_name] = benchmark_JuliaBUGS_model_with_Enzyme(model)
+#     catch e
+#         @warn "Error benchmarking $model_name with Enzyme: $e"
+#     end
+# end
+
 println("### Stan results:")
-Benchmark._print_results_table(stan_results; backend=Val(:markdown))
-println("### JuliaBUGS results:")
-Benchmark._print_results_table(juliabugs_results; backend=Val(:markdown))
+_print_results_table(stan_results; backend=Val(:markdown))
+println("### JuliaBUGS Mooncake results:")
+_print_results_table(juliabugs_results; backend=Val(:markdown))
+# println("### JuliaBUGS Enzyme results:")
+# _print_results_table(juliabugs_enzyme_results; backend=Val(:markdown))
```
