From 8d460c0dcb45c0668afb334c8f0b749a404088b5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 16:06:06 +0100 Subject: [PATCH 1/8] Make `run_ad` return both primal and grad --- HISTORY.md | 5 ++- benchmarks/benchmarks.jl | 53 ++++++++++++-------------- benchmarks/src/DynamicPPLBenchmarks.jl | 49 +++--------------------- src/test_utils/ad.jl | 39 +++++++++++++++---- 4 files changed, 66 insertions(+), 80 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index b59d8dd7f..7a581e0cf 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -24,9 +24,12 @@ Please see the API documentation for more details. There is now also an `rng` keyword argument to help seed parameter generation. -Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. +Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient. Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`. +Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain the time it took to calculate the gradient of logp and logp itself. +Previously there was only a single `time_vs_primal` field which represented the ratio of these two. + ### `DynamicPPL.TestUtils.check_model` You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index b733d810c..5400369a7 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -1,7 +1,6 @@ -using Pkg - -using DynamicPPLBenchmarks: Models, make_suite, model_dimension -using BenchmarkTools: @benchmark, median, run +using DynamicPPLBenchmarks: Models, to_backend, make_varinfo, model_dimension +using DynamicPPL.TestUtils.AD: run_ad, NoTest +using Chairmarks: @be using PrettyTables: PrettyTables, ft_printf using StableRNGs: StableRNG @@ -35,48 +34,45 @@ chosen_combinations = [ Models.simple_assume_observe(randn(rng)), :typed, :forwarddiff, - false, ), - ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), - ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), - ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), - ("Multivariate 10k", multivariate10k, :typed, :mooncake, true), - ("Dynamic", Models.dynamic(), :typed, :mooncake, true), - ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), - ("LDA", lda_instance, :typed, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff), + ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff), + ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff), + ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff), + ("Smorgasbord", smorgasbord_instance, :typed, :reversediff), + ("Smorgasbord", smorgasbord_instance, :typed, :mooncake), + ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake), + ("Multivariate 1k", multivariate1k, :typed, :mooncake), + ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake), + ("Multivariate 10k", multivariate10k, :typed, :mooncake), + ("Dynamic", Models.dynamic(), :typed, :mooncake), + ("Submodel", Models.parent(randn(rng)), :typed, :mooncake), + ("LDA", lda_instance, :typed, :reversediff), ] # Time running a model-like function that does not use DynamicPPL, as a reference point. # Eval timings will be relative to this. reference_time = begin obs = randn(rng) - median(@benchmark Models.simple_assume_observe_non_model(obs)).time + median(@be Models.simple_assume_observe_non_model(obs)).time end results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations @info "Running benchmark for $model_name" - suite = make_suite(model, varinfo_choice, adbackend, islinked) - results = run(suite) - eval_time = median(results["evaluation"]).time - relative_eval_time = eval_time / reference_time - ad_eval_time = median(results["gradient"]).time - relative_ad_eval_time = ad_eval_time / eval_time + adtype = to_backend(adbackend) + varinfo = make_varinfo(model, varinfo_choice) + ad_result = run_ad(model, adtype; test=NoTest(), benchmark=true, varinfo=varinfo) + relative_eval_time = ad_result.primal_time / reference_time + relative_ad_eval_time = ad_result.grad_time / ad_result.primal_time push!( results_table, ( model_name, - model_dimension(model, islinked), + length(varinfo[:]), string(adbackend), string(varinfo_choice), - islinked, relative_eval_time, relative_ad_eval_time, ), @@ -89,7 +85,6 @@ header = [ "Dimension", "AD Backend", "VarInfo Type", - "Linked", "Eval Time / Ref Time", "AD Time / Eval Time", ] @@ -97,6 +92,6 @@ PrettyTables.pretty_table( table_matrix; header=header, tf=PrettyTables.tf_markdown, - formatters=ft_printf("%.1f", [6, 7]), + formatters=ft_printf("%.1f", [5, 6]), crop=:none, # Always print the whole table, even if it doesn't fit in the terminal. ) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 8c5032ace..80bf15f24 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -14,21 +14,7 @@ using StableRNGs: StableRNG include("./Models.jl") using .Models: Models -export Models, make_suite, model_dimension - -""" - model_dimension(model, islinked) - -Return the dimension of `model`, accounting for linking, if any. -""" -function model_dimension(model, islinked) - vi = VarInfo() - model(StableRNG(23), vi) - if islinked - vi = DynamicPPL.link(vi, model) - end - return length(vi[:]) -end +export Models, to_backend, make_varinfo # Utility functions for representing AD backends using symbols. # Copied from TuringBenchmarking.jl. @@ -48,24 +34,20 @@ function to_backend(x::Union{AbstractString,Symbol}) end """ - make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) + make_varinfo(model, varinfo_choice::Symbol) -Create a benchmark suite for `model` using the selected varinfo type and AD backend. +Create a VarInfo for the given `model` using the selected varinfo type. Available varinfo choices: • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` • `:typed` → uses `DynamicPPL.typed_varinfo(model)` • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) -The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). - -`islinked` determines whether to link the VarInfo for evaluation. +The VarInfo is always linked. """ -function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) +function make_varinfo(model::Model, varinfo_choice::Symbol, adbackend::Symbol) rng = StableRNG(23) - suite = BenchmarkGroup() - vi = if varinfo_choice == :untyped DynamicPPL.untyped_varinfo(rng, model) elseif varinfo_choice == :typed @@ -80,26 +62,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: error("Unknown varinfo choice: $varinfo_choice") end - adbackend = to_backend(adbackend) - - if islinked - vi = DynamicPPL.link(vi, model) - end - - f = DynamicPPL.LogDensityFunction( - model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend - ) - # The parameters at which we evaluate f. - θ = vi[:] - - # Run once to trigger compilation. - LogDensityProblems.logdensity_and_gradient(f, θ) - suite["gradient"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ) - - # Also benchmark just standard model evaluation because why not. - suite["evaluation"] = @benchmarkable $(LogDensityProblems.logdensity)($f, $θ) - - return suite + return DynamicPPL.link(vi, model) end end # module diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 1ac33a481..95b5d6cd0 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa value_actual::Tresult "The gradient of logp (calculated using `adtype`)" grad_actual::Vector{Tresult} - "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" - time_vs_primal::Union{Nothing,Tresult} + "If benchmarking was requested, the time taken by the AD backend to evaluate the gradient + of logp" + grad_time::Union{Nothing,Tresult} + "If benchmarking was requested, the time taken by the AD backend to evaluate logp" + primal_time::Union{Nothing,Tresult} end """ @@ -121,6 +124,8 @@ end benchmark=false, atol::AbstractFloat=1e-8, rtol::AbstractFloat=sqrt(eps()), + getlogdensity::Function=getlogjoint_internal, + rng::AbstractRNG=default_rng(), varinfo::AbstractVarInfo=link(VarInfo(model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, verbose=true, @@ -174,6 +179,21 @@ Everything else is optional, and can be categorised into several groups: prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. +3. _Which type of logp is being calculated._ + + By default, `run_ad` evaluates the 'internal log joint density' of the model, + i.e., the log joint density in the unconstrained space. Thus, for example, in + + @model f() = x ~ LogNormal() + + the internal log joint density is `logpdf(Normal(), log(x))`. This is the + relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore + the most useful to test. + + If you want the log joint density in the original model parameterisation, you + can use `getlogjoint`. Likewise, if you want only the prior or likelihood, + you can use `getlogprior` or `getloglikelihood`, respectively. + 3. _How to specify the results to compare against._ Once logp and its gradient has been calculated with the specified `adtype`, @@ -277,12 +297,16 @@ function run_ad( end # Benchmark - time_vs_primal = if benchmark + grad_time, primal_time = if benchmark primal_benchmark = @be (ldf, params) logdensity(_[1], _[2]) grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2]) - t = median(grad_benchmark).time / median(primal_benchmark).time - verbose && println("grad / primal : $(t)") - t + median_primal = median(primal_benchmark).time + median_grad = median(grad_benchmark).time + r(f) = round(f; sigdigits=4) + verbose && println( + "grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))", + ) + (median_grad, median_primal) else nothing end @@ -299,7 +323,8 @@ function run_ad( grad_true, value, grad, - time_vs_primal, + grad_time, + primal_time, ) end From 90127abd21cbf17afde394cdcc968cd3dbbc909b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 16:13:21 +0100 Subject: [PATCH 2/8] Clean up deps --- benchmarks/Project.toml | 4 ---- benchmarks/src/DynamicPPLBenchmarks.jl | 12 +++++------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 3d14d03ff..9a0ccfadf 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -4,12 +4,10 @@ version = "0.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -20,11 +18,9 @@ DynamicPPL = {path = "../"} [compat] ADTypes = "1.14.0" -BenchmarkTools = "1.6.0" Distributions = "0.25.117" DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" -LogDensityProblems = "2.1.2" Mooncake = "0.4" PrettyTables = "2.4.0" ReverseDiff = "1.15.3" diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 80bf15f24..4eee86ef7 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,10 +1,8 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName -using BenchmarkTools: BenchmarkGroup, @benchmarkable +using DynamicPPL: Model, VarInfo, SimpleVarInfo using DynamicPPL: DynamicPPL using ADTypes: ADTypes -using LogDensityProblems: LogDensityProblems using ForwardDiff: ForwardDiff using Mooncake: Mooncake @@ -55,14 +53,14 @@ function make_varinfo(model::Model, varinfo_choice::Symbol, adbackend::Symbol) elseif varinfo_choice == :simple_namedtuple SimpleVarInfo{Float64}(model(rng)) elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) + vi = DynamicPPL.typed_varinfo(rng, model) + vals = DynamicPPL.values_as(vi, Dict) + SimpleVarInfo{Float64}(vals) else error("Unknown varinfo choice: $varinfo_choice") end - return DynamicPPL.link(vi, model) + return DynamicPPL.link!!(vi, model) end end # module From 2f82a033e6b05754ca99df9ba77f4cd16667013c Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 16:18:30 +0100 Subject: [PATCH 3/8] add Chairmarks --- benchmarks/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 9a0ccfadf..96137bb29 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -4,6 +4,7 @@ version = "0.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -18,6 +19,7 @@ DynamicPPL = {path = "../"} [compat] ADTypes = "1.14.0" +Chairmarks = "1.3.1" Distributions = "0.25.117" DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" From 637e4bc97be1a592926272f1de826f7998573df0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 16:19:52 +0100 Subject: [PATCH 4/8] Remove unused ReverseDiffCompiled --- benchmarks/src/DynamicPPLBenchmarks.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 4eee86ef7..5c0ee930c 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -19,7 +19,6 @@ export Models, to_backend, make_varinfo const SYMBOL_TO_BACKEND = Dict( :forwarddiff => ADTypes.AutoForwardDiff(), :reversediff => ADTypes.AutoReverseDiff(; compile=false), - :reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true), :mooncake => ADTypes.AutoMooncake(; config=nothing), ) From 2686fccbb104a736fa494401c5d0cfe5607dcb34 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 16:32:01 +0100 Subject: [PATCH 5/8] fix more bugs --- benchmarks/Project.toml | 2 ++ benchmarks/benchmarks.jl | 1 + src/test_utils/ad.jl | 2 +- 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 96137bb29..73ace2e56 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -13,6 +13,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [sources] DynamicPPL = {path = "../"} @@ -27,3 +28,4 @@ Mooncake = "0.4" PrettyTables = "2.4.0" ReverseDiff = "1.15.3" StableRNGs = "1" +Statistics = "1.11.1" diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 5400369a7..7b8d177da 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -3,6 +3,7 @@ using DynamicPPL.TestUtils.AD: run_ad, NoTest using Chairmarks: @be using PrettyTables: PrettyTables, ft_printf using StableRNGs: StableRNG +using Statistics: median rng = StableRNG(23) diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index 95b5d6cd0..068e1f346 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -308,7 +308,7 @@ function run_ad( ) (median_grad, median_primal) else - nothing + nothing, nothing end return ADResult( From b0216ff0b9e6129ccfc24150cd3e90b04aeaa09a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 17:01:51 +0100 Subject: [PATCH 6/8] more fixes --- benchmarks/benchmarks.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index 7b8d177da..47a8a2912 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -1,4 +1,4 @@ -using DynamicPPLBenchmarks: Models, to_backend, make_varinfo, model_dimension +using DynamicPPLBenchmarks: Models, to_backend, make_varinfo using DynamicPPL.TestUtils.AD: run_ad, NoTest using Chairmarks: @be using PrettyTables: PrettyTables, ft_printf @@ -58,9 +58,9 @@ reference_time = begin median(@be Models.simple_assume_observe_non_model(obs)).time end -results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[] +results_table = Tuple{String,Int,String,String,Float64,Float64}[] -for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations +for (model_name, model, varinfo_choice, adbackend) in chosen_combinations @info "Running benchmark for $model_name" adtype = to_backend(adbackend) varinfo = make_varinfo(model, varinfo_choice) From 63bb81fb574c152fccb847c418150ac1aa3a5532 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 29 Jul 2025 17:15:46 +0100 Subject: [PATCH 7/8] fix --- benchmarks/src/DynamicPPLBenchmarks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 5c0ee930c..079891097 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -42,7 +42,7 @@ Available varinfo choices: The VarInfo is always linked. """ -function make_varinfo(model::Model, varinfo_choice::Symbol, adbackend::Symbol) +function make_varinfo(model::Model, varinfo_choice::Symbol) rng = StableRNG(23) vi = if varinfo_choice == :untyped From 4bbebb7b6f7567279dcbf4cd18bb5c22e0f038a0 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 4 Aug 2025 11:23:12 +0100 Subject: [PATCH 8/8] Bump deps --- benchmarks/Project.toml | 12 ++++++------ benchmarks/src/DynamicPPLBenchmarks.jl | 10 ++++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index 73ace2e56..137544e56 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -19,13 +19,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" DynamicPPL = {path = "../"} [compat] -ADTypes = "1.14.0" -Chairmarks = "1.3.1" -Distributions = "0.25.117" +ADTypes = "1" +Chairmarks = "1" +Distributions = "0.25" DynamicPPL = "0.37" ForwardDiff = "0.10.38, 1" Mooncake = "0.4" -PrettyTables = "2.4.0" -ReverseDiff = "1.15.3" +PrettyTables = "2" +ReverseDiff = "1" StableRNGs = "1" -Statistics = "1.11.1" +Statistics = "1" diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 079891097..d204c1c0d 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -19,7 +19,7 @@ export Models, to_backend, make_varinfo const SYMBOL_TO_BACKEND = Dict( :forwarddiff => ADTypes.AutoForwardDiff(), :reversediff => ADTypes.AutoReverseDiff(; compile=false), - :mooncake => ADTypes.AutoMooncake(; config=nothing), + :mooncake => ADTypes.AutoMooncake(), ) to_backend(x) = error("Unknown backend: $x") @@ -37,8 +37,8 @@ Create a VarInfo for the given `model` using the selected varinfo type. Available varinfo choices: • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) + • `:simple_namedtuple` → builds a `SimpleVarInfo{Float64}(::NamedTuple)` + • `:simple_dict` → builds a `SimpleVarInfo{Float64}(::Dict)` The VarInfo is always linked. """ @@ -50,7 +50,9 @@ function make_varinfo(model::Model, varinfo_choice::Symbol) elseif varinfo_choice == :typed DynamicPPL.typed_varinfo(rng, model) elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) + vi = DynamicPPL.typed_varinfo(rng, model) + vals = DynamicPPL.values_as(vi, NamedTuple) + SimpleVarInfo{Float64}(vals) elseif varinfo_choice == :simple_dict vi = DynamicPPL.typed_varinfo(rng, model) vals = DynamicPPL.values_as(vi, Dict)