@@ -4,14 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
- using DynamicPPL:
-     Model,
-     LogDensityFunction,
-     VarInfo,
-     AbstractVarInfo,
-     link,
-     DefaultContext,
-     AbstractContext
+ using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
@@ -20,12 +13,48 @@ using Test: @test
export ADResult, run_ad, ADIncorrectException

"""
-     REFERENCE_ADTYPE
+     AbstractADCorrectnessTestSetting

- Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
- it's the default AD backend used in Turing.jl.
+ Different ways of testing the correctness of an AD backend.
"""
- const REFERENCE_ADTYPE = AutoForwardDiff()
+ abstract type AbstractADCorrectnessTestSetting end
+
+ """
+     WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+ Test correctness by comparing it against the result obtained with `adtype`.
+
+ `adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+ Turing.jl.
+ """
+ struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+     adtype::AD
+ end
+ WithBackend() = WithBackend(AutoForwardDiff())
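+
+ # Hypothetical usage sketch (editorial illustration, not part of this diff's
+ # API changes; assumes some `model::Model` is in scope and that the backend
+ # choice is arbitrary):
+ #
+ #     run_ad(model, AutoReverseDiff(); test=WithBackend())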
+
+ """
+     WithExpectedResult(
+         value::T,
+         grad::AbstractVector{T}
+     ) where {T<:AbstractFloat}
+     <: AbstractADCorrectnessTestSetting
+
+ Test correctness by comparing it against a known result (e.g. one obtained
+ analytically, or one obtained with a different backend previously). Both the
+ value of the primal (i.e. the log-density) and its gradient must be
+ supplied.
+ """
+ struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+     value::T
+     grad::AbstractVector{T}
+ end
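+
+ # Hypothetical usage sketch (editorial illustration; the numbers are
+ # placeholders standing in for an analytically known log-density and
+ # gradient):
+ #
+ #     run_ad(model, AutoReverseDiff(); test=WithExpectedResult(-1.23, [0.1, -0.4]))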
+
+ """
+     NoTest() <: AbstractADCorrectnessTestSetting
+
+ Disable correctness testing.
+ """
+ struct NoTest <: AbstractADCorrectnessTestSetting end
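+
+ # Hypothetical usage sketch (editorial illustration): skip the correctness
+ # check entirely, e.g. when only benchmarking is of interest:
+ #
+ #     run_ad(model, AutoReverseDiff(); test=NoTest(), benchmark=true)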

"""
    ADIncorrectException{T<:AbstractFloat}
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
-         test=true,
+         test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
        benchmark=false,
        value_atol=1e-6,
        grad_atol=1e-6,
        varinfo::AbstractVarInfo=link(VarInfo(model), model),
        params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-         reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-         expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
        verbose=true,
    )::ADResult

@@ -143,22 +170,25 @@ Everything else is optional, and can be categorised into several groups:
   prep_params)`. You could then evaluate the gradient at a different set of
   parameters using the `params` keyword argument.

- 3. _How to specify the results to compare against._ (Only if `test=true`.)
+ 3. _How to specify the results to compare against._

   Once logp and its gradient have been calculated with the specified `adtype`,
-    it must be tested for correctness.
+    it can optionally be tested for correctness. The exact way this is tested
+    is specified in the `test` parameter.

-    This can be done either by specifying `reference_adtype`, in which case logp
-    and its gradient will also be calculated with this reference in order to
-    obtain the ground truth; or by using `expected_value_and_grad`, which is a
-    tuple of `(logp, gradient)` that the calculated values must match. The
-    latter is useful if you are testing multiple AD backends and want to avoid
-    recalculating the ground truth multiple times.
+    There are several options for this, illustrated in the sketch below:

-    The default reference backend is ForwardDiff. If none of these parameters are
-    specified, ForwardDiff will be used to calculate the ground truth.
+    - You can explicitly specify the correct value using
+      [`WithExpectedResult()`](@ref).
+    - You can compare against the result obtained with a different AD backend
+      using [`WithBackend(adtype)`](@ref).
+    - You can disable testing by passing [`NoTest()`](@ref).
+    - The default is to compare against the result obtained with ForwardDiff,
+      i.e. `WithBackend(AutoForwardDiff())`.
+    - `test=false` and `test=true` are synonyms for `NoTest()` and
+      `WithBackend(AutoForwardDiff())`, respectively.
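+
+    For instance (a hypothetical sketch; `model` stands for any
+    `DynamicPPL.Model`, and the backend choices are purely illustrative):
+
+        run_ad(model, AutoReverseDiff(); test=WithBackend())  # ForwardDiff reference
+        run_ad(model, AutoReverseDiff(); test=NoTest())       # skip the check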

- 4. _How to specify the tolerances._ (Only if `test=true`.)
+ 4. _How to specify the tolerances._ (Only if testing is enabled.)

   The tolerances for the value and gradient can be set using `value_atol` and
   `grad_atol`. These default to 1e-6.
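+
+    For instance, to tighten the value check while loosening the gradient
+    check (a hypothetical sketch):
+
+        run_ad(model, AutoReverseDiff(); value_atol=1e-8, grad_atol=1e-4)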
@@ -180,48 +210,57 @@ thrown as-is.
function run_ad(
    model::Model,
    adtype::AbstractADType;
-     test::Bool=true,
+     test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
    benchmark::Bool=false,
    value_atol::AbstractFloat=1e-6,
    grad_atol::AbstractFloat=1e-6,
    varinfo::AbstractVarInfo=link(VarInfo(model), model),
    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-     expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
    verbose=true,
)::ADResult
+     # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+     if test isa Bool
+         test = test ? WithBackend() : NoTest()
+     end
+
+     # Extract parameters
    if isnothing(params)
        params = varinfo[:]
    end
    params = map(identity, params) # Concretise

+     # Calculate log-density and gradient with the backend of interest
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    verbose && println("params   : $(params)")
    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
    value, grad = logdensity_and_gradient(ldf, params)
+     # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
    grad = collect(grad)
    verbose && println("actual   : $((value, grad))")

-     if test
-         # Calculate ground truth to compare against
-         value_true, grad_true = if expected_value_and_grad === nothing
-             ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-             logdensity_and_gradient(ldf_reference, params)
-         else
-             expected_value_and_grad
+     # Test correctness
+     if test isa NoTest
+         value_true = nothing
+         grad_true = nothing
+     else
+         # Get the correct result
+         if test isa WithExpectedResult
+             value_true = test.value
+             grad_true = test.grad
+         elseif test isa WithBackend
+             ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+             value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+             # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+             grad_true = collect(grad_true)
        end
+         # Perform testing
        verbose && println("expected : $((value_true, grad_true))")
-         grad_true = collect(grad_true)
-
        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
        isapprox(value, value_true; atol=value_atol) || exc()
        isapprox(grad, grad_true; atol=grad_atol) || exc()
-     else
-         value_true = nothing
-         grad_true = nothing
    end

+     # Benchmark
    time_vs_primal = if benchmark
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])