1
- module DynamicPPLBenchmarks
1
+ module DynamicPPLBenchmarkToolsExt
2
2
3
- using DynamicPPL: VarInfo, SimpleVarInfo, VarName
3
+ using DynamicPPL:
4
+ DynamicPPL, ADTypes, LogDensityProblems, Model, VarInfo, SimpleVarInfo, VarName
4
5
using BenchmarkTools: BenchmarkGroup, @benchmarkable
5
- using DynamicPPL: DynamicPPL
6
- using ADTypes: ADTypes
7
- using LogDensityProblems: LogDensityProblems
8
-
9
- using ForwardDiff: ForwardDiff
10
- using Mooncake: Mooncake
11
- using ReverseDiff: ReverseDiff
12
- using StableRNGs: StableRNG
13
-
14
- include (" ./Models.jl" )
15
- using . Models: Models
16
-
17
- export Models, make_suite, model_dimension
6
+ using Random: Random
18
7
19
8
"""
20
- model_dimension(model, islinked)
21
-
22
- Return the dimension of `model`, accounting for linking, if any.
23
- """
24
- function model_dimension (model, islinked)
25
- vi = VarInfo ()
26
- model (vi)
27
- if islinked
28
- vi = DynamicPPL. link (vi, model)
29
- end
30
- return length (vi[:])
31
- end
32
-
33
- # Utility functions for representing AD backends using symbols.
34
- # Copied from TuringBenchmarking.jl.
35
- const SYMBOL_TO_BACKEND = Dict (
36
- :forwarddiff => ADTypes. AutoForwardDiff (),
37
- :reversediff => ADTypes. AutoReverseDiff (; compile= false ),
38
- :reversediff_compiled => ADTypes. AutoReverseDiff (; compile= true ),
39
- :mooncake => ADTypes. AutoMooncake (; config= nothing ),
40
- )
41
-
42
- to_backend (x) = error (" Unknown backend: $x " )
43
- to_backend (x:: ADTypes.AbstractADType ) = x
44
- function to_backend (x:: Union{AbstractString,Symbol} )
45
- k = Symbol (lowercase (string (x)))
46
- haskey (SYMBOL_TO_BACKEND, k) || error (" Unknown backend: $x " )
47
- return SYMBOL_TO_BACKEND[k]
48
- end
49
-
50
- """
51
- make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool)
9
+ make_benchmark_suite(
10
+ [rng::Random.AbstractRNG,]
11
+ model::Model,
12
+ varinfo_choice::Symbol,
13
+ adtype::ADTypes.AbstractADType,
14
+ islinked::Bool
15
+ )
52
16
53
17
Create a benchmark suite for `model` using the selected varinfo type and AD backend.
54
18
Available varinfo choices:
@@ -57,13 +21,15 @@ Available varinfo choices:
57
21
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
58
22
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)
59
23
60
- The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`).
61
-
62
24
`islinked` determines whether to link the VarInfo for evaluation.
63
25
"""
64
- function make_suite (model, varinfo_choice:: Symbol , adbackend:: Symbol , islinked:: Bool )
65
- rng = StableRNG (23 )
66
-
26
+ function make_benchmark_suite (
27
+ rng:: Random.AbstractRNG ,
28
+ model:: Model ,
29
+ varinfo_choice:: Symbol ,
30
+ adtype:: ADTypes.AbstractADType ,
31
+ islinked:: Bool ,
32
+ )
67
33
suite = BenchmarkGroup ()
68
34
69
35
vi = if varinfo_choice == :untyped
@@ -82,14 +48,13 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
82
48
error (" Unknown varinfo choice: $varinfo_choice " )
83
49
end
84
50
85
- adbackend = to_backend (adbackend)
86
51
context = DynamicPPL. DefaultContext ()
87
52
88
53
if islinked
89
54
vi = DynamicPPL. link (vi, model)
90
55
end
91
56
92
- f = DynamicPPL. LogDensityFunction (model, vi, context; adtype= adbackend )
57
+ f = DynamicPPL. LogDensityFunction (model, vi, context; adtype= adtype )
93
58
# The parameters at which we evaluate f.
94
59
θ = vi[:]
95
60
@@ -102,5 +67,12 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
102
67
103
68
return suite
104
69
end
70
+ function make_benchmark_suite (
71
+ model:: Model , varinfo_choice:: Symbol , adtype:: Symbol , islinked:: Bool
72
+ )
73
+ return make_benchmark_suite (
74
+ Random. default_rng (), model, varinfo_choice, adtype, islinked
75
+ )
76
+ end
105
77
106
- end # module
78
+ end
0 commit comments