Skip to content

Commit bbc82ef

Browse files
committed
Move code to BenchmarkToolsExt
1 parent bb59885 commit bbc82ef

File tree

5 files changed

+63
-66
lines changed

5 files changed

+63
-66
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
2525
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2626

2727
[weakdeps]
28+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
2829
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2930
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3031
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -34,6 +35,7 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3435
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3536

3637
[extensions]
38+
DynamicPPLBenchmarkToolsExt = ["BenchmarkTools"]
3739
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3840
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3941
DynamicPPLForwardDiffExt = ["ForwardDiff"]
@@ -47,6 +49,7 @@ AbstractMCMC = "5"
4749
AbstractPPL = "0.10.1"
4850
Accessors = "0.1"
4951
BangBang = "0.4.1"
52+
BenchmarkTools = "1.6.0"
5053
Bijectors = "0.13.18, 0.14, 0.15"
5154
ChainRulesCore = "1"
5255
Compat = "4"
File renamed without changes.

benchmarks/benchmarks.jl

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,42 @@ using Pkg
22
# To ensure we benchmark the local version of DynamicPPL, dev the folder above.
33
Pkg.develop(; path=joinpath(@__DIR__, ".."))
44

5-
using DynamicPPLBenchmarks: Models, make_suite, model_dimension
5+
using DynamicPPL: DynamicPPL, make_benchmark_suite, VarInfo
66
using BenchmarkTools: @benchmark, median, run
77
using PrettyTables: PrettyTables, ft_printf
8+
using ForwardDiff: ForwardDiff
9+
using Mooncake: Mooncake
10+
using ReverseDiff: ReverseDiff
811
using StableRNGs: StableRNG
912

10-
rng = StableRNG(23)
13+
include("Models.jl")
14+
15+
"""
16+
model_dimension(model, islinked)
17+
18+
Return the dimension of `model`, accounting for linking, if any.
19+
"""
20+
function model_dimension(model, islinked)
21+
vi = VarInfo()
22+
model(vi)
23+
if islinked
24+
vi = DynamicPPL.link(vi, model)
25+
end
26+
return length(vi[:])
27+
end
1128

1229
# Create DynamicPPL.Model instances to run benchmarks on.
13-
smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))
30+
smorgasbord_instance = Models.smorgasbord(
31+
randn(StableRNG(23), 100), randn(StableRNG(23), 100)
32+
)
1433
loop_univariate1k, multivariate1k = begin
15-
data_1k = randn(rng, 1_000)
34+
data_1k = randn(StableRNG(23), 1_000)
1635
loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
1736
multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
1837
loop, multi
1938
end
2039
loop_univariate10k, multivariate10k = begin
21-
data_10k = randn(rng, 10_000)
40+
data_10k = randn(StableRNG(23), 10_000)
2241
loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
2342
multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
2443
loop, multi
@@ -34,7 +53,7 @@ end
3453
chosen_combinations = [
3554
(
3655
"Simple assume observe",
37-
Models.simple_assume_observe(randn(rng)),
56+
Models.simple_assume_observe(randn(StableRNG(23))),
3857
:typed,
3958
:forwarddiff,
4059
false,
@@ -50,22 +69,22 @@ chosen_combinations = [
5069
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
5170
("Multivariate 10k", multivariate10k, :typed, :mooncake, true),
5271
("Dynamic", Models.dynamic(), :typed, :mooncake, true),
53-
("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true),
72+
("Submodel", Models.parent(randn(StableRNG(23))), :typed, :mooncake, true),
5473
("LDA", lda_instance, :typed, :reversediff, true),
5574
]
5675

5776
# Time running a model-like function that does not use DynamicPPL, as a reference point.
5877
# Eval timings will be relative to this.
5978
reference_time = begin
60-
obs = randn(rng)
79+
obs = randn(StableRNG(23))
6180
median(@benchmark Models.simple_assume_observe_non_model(obs)).time
6281
end
6382

6483
results_table = Tuple{String,Int,String,String,Bool,Float64,Float64}[]
6584

6685
for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations
6786
@info "Running benchmark for $model_name"
68-
suite = make_suite(model, varinfo_choice, adbackend, islinked)
87+
suite = make_benchmark_suite(StableRNG(23), model, varinfo_choice, adbackend, islinked)
6988
results = run(suite)
7089
eval_time = median(results["evaluation"]).time
7190
relative_eval_time = eval_time / reference_time
Lines changed: 27 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,18 @@
1-
module DynamicPPLBenchmarks
1+
module DynamicPPLBenchmarkToolsExt
22

3-
using DynamicPPL: VarInfo, SimpleVarInfo, VarName
3+
using DynamicPPL:
4+
DynamicPPL, ADTypes, LogDensityProblems, Model, VarInfo, SimpleVarInfo, VarName
45
using BenchmarkTools: BenchmarkGroup, @benchmarkable
5-
using DynamicPPL: DynamicPPL
6-
using ADTypes: ADTypes
7-
using LogDensityProblems: LogDensityProblems
8-
9-
using ForwardDiff: ForwardDiff
10-
using Mooncake: Mooncake
11-
using ReverseDiff: ReverseDiff
12-
using StableRNGs: StableRNG
13-
14-
include("./Models.jl")
15-
using .Models: Models
16-
17-
export Models, make_suite, model_dimension
6+
using Random: Random
187

198
"""
20-
model_dimension(model, islinked)
21-
22-
Return the dimension of `model`, accounting for linking, if any.
23-
"""
24-
function model_dimension(model, islinked)
25-
vi = VarInfo()
26-
model(vi)
27-
if islinked
28-
vi = DynamicPPL.link(vi, model)
29-
end
30-
return length(vi[:])
31-
end
32-
33-
# Utility functions for representing AD backends using symbols.
34-
# Copied from TuringBenchmarking.jl.
35-
const SYMBOL_TO_BACKEND = Dict(
36-
:forwarddiff => ADTypes.AutoForwardDiff(),
37-
:reversediff => ADTypes.AutoReverseDiff(; compile=false),
38-
:reversediff_compiled => ADTypes.AutoReverseDiff(; compile=true),
39-
:mooncake => ADTypes.AutoMooncake(; config=nothing),
40-
)
41-
42-
to_backend(x) = error("Unknown backend: $x")
43-
to_backend(x::ADTypes.AbstractADType) = x
44-
function to_backend(x::Union{AbstractString,Symbol})
45-
k = Symbol(lowercase(string(x)))
46-
haskey(SYMBOL_TO_BACKEND, k) || error("Unknown backend: $x")
47-
return SYMBOL_TO_BACKEND[k]
48-
end
49-
50-
"""
51-
make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
9+
make_benchmark_suite(
10+
[rng::Random.AbstractRNG,]
11+
model::Model,
12+
varinfo_choice::Symbol,
13+
adtype::ADTypes.AbstractADType,
14+
islinked::Bool
15+
)
5216
5317
Create a benchmark suite for `model` using the selected varinfo type and AD backend.
5418
Available varinfo choices:
@@ -57,13 +21,15 @@ Available varinfo choices:
5721
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
5822
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
5923
60-
The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
61-
6224
`islinked` determines whether to link the VarInfo for evaluation.
6325
"""
64-
function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
65-
rng = StableRNG(23)
66-
26+
function make_benchmark_suite(
27+
rng::Random.AbstractRNG,
28+
model::Model,
29+
varinfo_choice::Symbol,
30+
adtype::ADTypes.AbstractADType,
31+
islinked::Bool,
32+
)
6733
suite = BenchmarkGroup()
6834

6935
vi = if varinfo_choice == :untyped
@@ -82,14 +48,13 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
8248
error("Unknown varinfo choice: $varinfo_choice")
8349
end
8450

85-
adbackend = to_backend(adbackend)
8651
context = DynamicPPL.DefaultContext()
8752

8853
if islinked
8954
vi = DynamicPPL.link(vi, model)
9055
end
9156

92-
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adbackend)
57+
f = DynamicPPL.LogDensityFunction(model, vi, context; adtype=adtype)
9358
# The parameters at which we evaluate f.
9459
θ = vi[:]
9560

@@ -102,5 +67,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
10267

10368
return suite
10469
end
70+
function make_benchmark_suite(
71+
model::Model, varinfo_choice::Symbol, adtype::Symbol, islinked::Bool
72+
)
73+
return make_benchmark_suite(
74+
Random.default_rng(), model, varinfo_choice, adtype, islinked
75+
)
76+
end
10577

106-
end # module
78+
end

src/DynamicPPL.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,11 @@ if isdefined(Base.Experimental, :register_error_hint)
217217
end
218218
end
219219

220-
# Standard tag: Improves stacktraces
221-
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
220+
# DynamicPPLForwardDiffExt
221+
# Improves stacktraces, see https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
222222
struct DynamicPPLTag end
223223

224+
# DynamicPPLBenchmarkToolsExt
225+
function make_benchmark_suite end
226+
224227
end # module

0 commit comments

Comments
 (0)