@@ -4,28 +4,57 @@ using ADTypes: AbstractADType, AutoForwardDiff
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
- using DynamicPPL:
-     Model,
-     LogDensityFunction,
-     VarInfo,
-     AbstractVarInfo,
-     link,
-     DefaultContext,
-     AbstractContext
+ using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using LogDensityProblems: logdensity, logdensity_and_gradient
- using Random: Random, Xoshiro
+ using Random: AbstractRNG, default_rng
using Statistics: median
using Test: @test

- export ADResult, run_ad, ADIncorrectException
+ export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest

"""
-     REFERENCE_ADTYPE
+     AbstractADCorrectnessTestSetting

- Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
- it's the default AD backend used in Turing.jl.
+ Different ways of testing the correctness of an AD backend.
"""
- const REFERENCE_ADTYPE = AutoForwardDiff()
+ abstract type AbstractADCorrectnessTestSetting end
+
+ """
+     WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+ Test correctness by comparing the result against that obtained with `adtype`.
+
+ `adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+ Turing.jl.
+ """
+ struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+     adtype::AD
+ end
+ WithBackend() = WithBackend(AutoForwardDiff())
+
+ """
+     WithExpectedResult(
+         value::T,
+         grad::AbstractVector{T}
+     ) where {T<:AbstractFloat}
+     <: AbstractADCorrectnessTestSetting
+
+ Test correctness by comparing the result against a known result (e.g. one obtained
+ analytically, or one obtained with a different backend previously). Both the
+ value of the primal (i.e. the log-density) and its gradient must be supplied.
+ """
+ struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+     value::T
+     grad::AbstractVector{T}
+ end
+
+ """
+     NoTest() <: AbstractADCorrectnessTestSetting
+
+ Disable correctness testing.
+ """
+ struct NoTest <: AbstractADCorrectnessTestSetting end

"""
    ADIncorrectException{T<:AbstractFloat}
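As a quick orientation, here is a sketch of how the three settings added in this hunk might be constructed. The module path `DynamicPPL.TestUtils.AD` and the expected value/gradient numbers are assumptions for illustration.

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DynamicPPL.TestUtils.AD: WithBackend, WithExpectedResult, NoTest  # assumed module path

# Compare against the default reference backend (ForwardDiff).
setting_default = WithBackend()

# Compare against a specific backend instead.
setting_reversediff = WithBackend(AutoReverseDiff())

# Compare against a known log-density and gradient (hypothetical numbers).
setting_known = WithExpectedResult(-1.23, [0.5, -0.5])

# Skip correctness testing altogether.
setting_none = NoTest()
```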
@@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
end

"""
-     ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+     ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}

Data structure to store the results of the AD correctness test.

The type parameter `Tparams` is the numeric type of the parameters passed in;
- `Tresult` is the type of the value and the gradient.
+ `Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+ absolute and relative tolerances used for correctness testing.

# Fields
$(TYPEDFIELDS)
"""
- struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
    "The DynamicPPL model that was tested"
    model::Model
    "The VarInfo that was used"
@@ -64,18 +94,18 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
    params::Vector{Tparams}
    "The AD backend that was tested"
    adtype::AbstractADType
-     "The absolute tolerance for the value of logp"
-     value_atol::Tresult
-     "The absolute tolerance for the gradient of logp"
-     grad_atol::Tresult
+     "Absolute tolerance used for correctness test"
+     atol::Ttol
+     "Relative tolerance used for correctness test"
+     rtol::Ttol
    "The expected value of logp"
    value_expected::Union{Nothing,Tresult}
    "The expected gradient of logp"
    grad_expected::Union{Nothing,Vector{Tresult}}
    "The value of logp (calculated using `adtype`)"
-     value_actual::Union{Nothing,Tresult}
+     value_actual::Tresult
    "The gradient of logp (calculated using `adtype`)"
-     grad_actual::Union{Nothing,Vector{Tresult}}
+     grad_actual::Vector{Tresult}
    "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
    time_vs_primal::Union{Nothing,Tresult}
end
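A hedged sketch of inspecting the fields of the updated `ADResult` after a call to `run_ad`; the `@model` definition and the module path are assumed for illustration.

```julia
using DynamicPPL, Distributions, ForwardDiff
using ADTypes: AutoForwardDiff
using DynamicPPL.TestUtils.AD: run_ad  # assumed module path

# Hypothetical model, used only to produce an ADResult to inspect.
@model function demo()
    x ~ Normal()
    y ~ Normal(x)
end

result = run_ad(demo(), AutoForwardDiff(); benchmark=true)

result.value_actual    # logp computed with the tested backend (now always a float, never nothing)
result.grad_actual     # gradient computed with the tested backend
result.atol, result.rtol   # tolerances used for the correctness check
result.time_vs_primal  # gradient time / primal time, since benchmark=true
```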
    run_ad(
        model::Model,
        adtype::ADTypes.AbstractADType;
-         test=true,
+         test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
        benchmark=false,
-         value_atol=1e-6,
-         grad_atol=1e-6,
+         atol::AbstractFloat=100 * eps(),
+         rtol::AbstractFloat=sqrt(eps()),
        varinfo::AbstractVarInfo=link(VarInfo(model), model),
        params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-         reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-         expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
        verbose=true,
    )::ADResult
@@ -133,8 +161,8 @@ Everything else is optional, and can be categorised into several groups:

Note that if the VarInfo is not specified (and thus automatically generated)
the parameters in it will have been sampled from the prior of the model. If
- you want to seed the parameter generation, the easiest way is to pass a
- `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+ you want to seed the parameter generation for the VarInfo, you can pass the
+ `rng` keyword argument, which will then be used to create the VarInfo.

Finally, note that these only reflect the parameters used for _evaluating_
the gradient. If you also want to control the parameters used for
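A minimal sketch of the seeding behaviour described in the hunk above, assuming the new `rng` keyword and a hypothetical model:

```julia
using DynamicPPL, Distributions, ForwardDiff
using ADTypes: AutoForwardDiff
using Random: Xoshiro
using DynamicPPL.TestUtils.AD: run_ad  # assumed module path

@model function demo()
    x ~ Normal()
end

# The VarInfo (and hence the evaluation point) is generated from this rng,
# so repeated calls with the same seed evaluate the gradient at the same params.
result_a = run_ad(demo(), AutoForwardDiff(); rng=Xoshiro(468))
result_b = run_ad(demo(), AutoForwardDiff(); rng=Xoshiro(468))
result_a.params == result_b.params  # expected to hold under this assumption
```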
@@ -143,25 +171,35 @@ Everything else is optional, and can be categorised into several groups:
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

- 3. _How to specify the results to compare against._ (Only if `test=true`.)
+ 3. _How to specify the results to compare against._

Once logp and its gradient has been calculated with the specified `adtype`,
- it must be tested for correctness.
+ it can optionally be tested for correctness. The exact way this is tested
+ is specified in the `test` parameter.
+
+ There are several options for this:

- This can be done either by specifying `reference_adtype`, in which case logp
- and its gradient will also be calculated with this reference in order to
- obtain the ground truth; or by using `expected_value_and_grad`, which is a
- tuple of `(logp, gradient)` that the calculated values must match. The
- latter is useful if you are testing multiple AD backends and want to avoid
- recalculating the ground truth multiple times.
+ - You can explicitly specify the correct value using
+   [`WithExpectedResult()`](@ref).
+ - You can compare against the result obtained with a different AD backend
+   using [`WithBackend(adtype)`](@ref).
+ - You can disable testing by passing [`NoTest()`](@ref).
+ - The default is to compare against the result obtained with ForwardDiff,
+   i.e. `WithBackend(AutoForwardDiff())`.
+ - `test=false` and `test=true` are synonyms for
+   `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.

- The default reference backend is ForwardDiff. If none of these parameters are
- specified, ForwardDiff will be used to calculate the ground truth.
+ 4. _How to specify the tolerances._ (Only if testing is enabled.)

- 4. _How to specify the tolerances._ (Only if `test=true`.)
+ Both absolute and relative tolerances can be specified using the `atol` and
+ `rtol` keyword arguments respectively. The behaviour of these is similar to
+ `isapprox()`, i.e. the value and gradient are considered correct if either
+ atol or rtol is satisfied. The default values are `100*eps()` for `atol` and
+ `sqrt(eps())` for `rtol`.

- The tolerances for the value and gradient can be set using `value_atol` and
- `grad_atol`. These default to 1e-6.
+ For the most part, it is the `rtol` check that is more meaningful, because
+ we cannot know the magnitude of logp and its gradient a priori. The `atol`
+ value is supplied to handle the case where gradients are equal to zero.

5. _Whether to output extra logging information._

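The tolerance semantics described in item 4 mirror `isapprox`; a brief standalone illustration in plain Julia:

```julia
# The same defaults as in the new signature below.
atol = 100 * eps()   # ≈ 2.2e-14
rtol = sqrt(eps())   # ≈ 1.5e-8

# The check passes if EITHER tolerance is satisfied, as with isapprox:
isapprox(1e6 + 1e-3, 1e6; atol=atol, rtol=rtol)  # true: relative error is tiny
isapprox(1e-20, 0.0; atol=atol, rtol=rtol)       # true: atol handles zero gradients
isapprox(1.0, 1.001; atol=atol, rtol=rtol)       # false: neither tolerance holds
```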
@@ -180,48 +218,58 @@ thrown as-is.
function run_ad(
    model::Model,
    adtype::AbstractADType;
-     test::Bool=true,
+     test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
    benchmark::Bool=false,
-     value_atol::AbstractFloat=1e-6,
-     grad_atol::AbstractFloat=1e-6,
-     varinfo::AbstractVarInfo=link(VarInfo(model), model),
+     atol::AbstractFloat=100 * eps(),
+     rtol::AbstractFloat=sqrt(eps()),
+     rng::AbstractRNG=default_rng(),
+     varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-     expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
    verbose=true,
)::ADResult
+     # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+     if test isa Bool
+         test = test ? WithBackend() : NoTest()
+     end
+
+     # Extract parameters
    if isnothing(params)
        params = varinfo[:]
    end
    params = map(identity, params) # Concretise

+     # Calculate log-density and gradient with the backend of interest
    verbose && @info "Running AD on $(model.f) with $(adtype)\n"
    verbose && println("params : $(params)")
    ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
    value, grad = logdensity_and_gradient(ldf, params)
+     # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
    grad = collect(grad)
    verbose && println("actual : $((value, grad))")

-     if test
-         # Calculate ground truth to compare against
-         value_true, grad_true = if expected_value_and_grad === nothing
-             ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-             logdensity_and_gradient(ldf_reference, params)
-         else
-             expected_value_and_grad
+     # Test correctness
+     if test isa NoTest
+         value_true = nothing
+         grad_true = nothing
+     else
+         # Get the correct result
+         if test isa WithExpectedResult
+             value_true = test.value
+             grad_true = test.grad
+         elseif test isa WithBackend
+             ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+             value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+             # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+             grad_true = collect(grad_true)
        end
+         # Perform testing
        verbose && println("expected : $((value_true, grad_true))")
-         grad_true = collect(grad_true)
-
        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-         isapprox(value, value_true; atol=value_atol) || exc()
-         isapprox(grad, grad_true; atol=grad_atol) || exc()
-     else
-         value_true = nothing
-         grad_true = nothing
+         isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+         isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
    end

+     # Benchmark
    time_vs_primal = if benchmark
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
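To tie the new code path together, a hedged usage sketch; the `@model` definition is hypothetical, and the backend packages are loaded explicitly because DifferentiationInterface requires them:

```julia
using DynamicPPL, Distributions, ForwardDiff, ReverseDiff
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, NoTest  # assumed module path

@model function demo()
    x ~ Normal()
end

# Test ReverseDiff against a ForwardDiff reference; throws ADIncorrectException
# if the value or gradient disagree beyond atol/rtol, and also benchmarks.
res = run_ad(demo(), AutoReverseDiff(); test=WithBackend(AutoForwardDiff()), benchmark=true)
res.time_vs_primal  # gradient cost relative to a single logp evaluation

# Boolean values are still accepted and converted internally:
run_ad(demo(), AutoReverseDiff(); test=true)   # same as test=WithBackend()
run_ad(demo(), AutoReverseDiff(); test=false)  # same as test=NoTest()
```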
@@ -237,8 +285,8 @@ function run_ad(
        varinfo,
        params,
        adtype,
-         value_atol,
-         grad_atol,
+         atol,
+         rtol,
        value_true,
        grad_true,
        value,