Commit a0289db

Improve API for AD testing (#964)
* Rework API for AD testing
* Fix test
* Add `rng` keyword argument
* Use atol and rtol
* remove unbound type parameter (?)
* Don't need to do elementwise check
* Update changelog
* Fix typo
1 parent 57a53e1 commit a0289db

File tree

4 files changed: +149 −73 lines changed


HISTORY.md

Lines changed: 11 additions & 0 deletions
```diff
@@ -8,6 +8,17 @@
 
 The `@submodel` macro is fully removed; please use `to_submodel` instead.
 
+### `DynamicPPL.TestUtils.AD.run_ad`
+
+The three keyword arguments `test`, `reference_backend`, and `expected_value_and_grad`
+have been merged into a single `test` keyword argument. Please see the API
+documentation for more details. (The old `test=true` and `test=false` values are
+still valid, and you only need to adjust the invocation if you were explicitly
+passing the `reference_backend` or `expected_value_and_grad` arguments.)
+
+There is now also an `rng` keyword argument to help seed parameter generation.
+
+Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify
+`atol` and `rtol`, which are used for both the value and the gradient. Their
+semantics are the same as in Julia's `isapprox`; two values are considered equal
+if they satisfy either `atol` or `rtol`.
+
 ### Accumulators
 
 This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
```
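For readers migrating to the new API, here is a brief sketch (not part of the commit) of what an updated call might look like; the model `demo`, the backend, the seed, and the tolerance values are illustrative placeholders:

```julia
using ADTypes: AutoReverseDiff
using Distributions: Normal
using DynamicPPL
using DynamicPPL.TestUtils.AD: run_ad, WithBackend
using Random: Xoshiro

# A placeholder model; substitute your own.
@model demo() = x ~ Normal()

# Before:  run_ad(model, adtype; reference_backend=..., expected_value_and_grad=...)
# After:   a single `test` keyword carries the comparison strategy.
result = run_ad(
    demo(),
    AutoReverseDiff();
    test=WithBackend(),   # compare against ForwardDiff (the default)
    rng=Xoshiro(468),     # new: seeds parameter generation
    atol=1e-10,           # new: shared absolute tolerance
    rtol=1e-6,            # new: shared relative tolerance
)
```

As noted above, `test=true` and `test=false` remain valid shorthands, so callers that never passed `reference_backend` or `expected_value_and_grad` need no changes.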

docs/src/api.md

Lines changed: 15 additions & 0 deletions
````diff
@@ -206,6 +206,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
 
 ```@docs
 DynamicPPL.TestUtils.AD.run_ad
+```
+
+The default test setting is to compare against ForwardDiff.
+You can have more fine-grained control over how to test the AD backend using the following types:
+
+```@docs
+DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting
+DynamicPPL.TestUtils.AD.WithBackend
+DynamicPPL.TestUtils.AD.WithExpectedResult
+DynamicPPL.TestUtils.AD.NoTest
+```
+
+These are returned / thrown by the `run_ad` function:
+
+```@docs
 DynamicPPL.TestUtils.AD.ADResult
 DynamicPPL.TestUtils.AD.ADIncorrectException
 ```
````
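As a rough sketch of how the three test settings documented above fit together (the `model` variable and the choice of backends here are placeholders, not from this commit):

```julia
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, WithExpectedResult, NoTest

# `model` stands in for any DynamicPPL model.

# Compare against an explicit reference backend:
run_ad(model, AutoReverseDiff(); test=WithBackend(AutoForwardDiff()))

# Compute the reference once, then reuse it, avoiding recalculating
# the ground truth for every backend under test:
ref = run_ad(model, AutoForwardDiff(); test=NoTest())
run_ad(
    model,
    AutoReverseDiff();
    test=WithExpectedResult(ref.value_actual, ref.grad_actual),
)
```

Reusing a `NoTest()` result this way mirrors the pattern adopted in `test/ad.jl` below.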

src/test_utils/ad.jl

Lines changed: 114 additions & 66 deletions
```diff
@@ -4,28 +4,57 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
-using Random: Random, Xoshiro
+using Random: AbstractRNG, default_rng
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad, ADIncorrectException
+export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest
 
 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting
 
-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end
+
+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against the result obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())
+
+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T <: AbstractFloat}
+    <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against a known result (e.g. one obtained
+analytically, or one obtained with a different backend previously). Both the
+value of the primal (i.e. the log-density) as well as its gradient must be
+supplied.
+"""
+struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+    value::T
+    grad::AbstractVector{T}
+end
+
+"""
+    NoTest() <: AbstractADCorrectnessTestSetting
+
+Disable correctness testing.
+"""
+struct NoTest <: AbstractADCorrectnessTestSetting end
 
 """
     ADIncorrectException{T<:AbstractFloat}
@@ -45,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
 end
 
 """
-    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
 
 The type parameter `Tparams` is the numeric type of the parameters passed in;
-`Tresult` is the type of the value and the gradient.
+`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+absolute and relative tolerances used for correctness testing.
 
 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
@@ -64,18 +94,18 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
     grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Tresult}
+    value_actual::Tresult
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Tresult}}
+    grad_actual::Vector{Tresult}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
     time_vs_primal::Union{Nothing,Tresult}
 end
@@ -84,14 +114,13 @@ end
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
-        varinfo::AbstractVarInfo=link(VarInfo(model), model),
+        atol::AbstractFloat=100 * eps(),
+        rtol::AbstractFloat=sqrt(eps()),
+        rng::AbstractRNG=default_rng(),
+        varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
@@ -133,8 +161,8 @@ Everything else is optional, and can be categorised into several groups:
 
    Note that if the VarInfo is not specified (and thus automatically generated)
    the parameters in it will have been sampled from the prior of the model. If
-   you want to seed the parameter generation, the easiest way is to pass a
-   `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+   you want to seed the parameter generation for the VarInfo, you can pass the
+   `rng` keyword argument, which will then be used to create the VarInfo.
 
    Finally, note that these only reflect the parameters used for _evaluating_
   the gradient. If you also want to control the parameters used for
@@ -143,25 +171,35 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._
 
    Once logp and its gradient has been calculated with the specified `adtype`,
-   it must be tested for correctness.
+   it can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` parameter.
+
+   There are several options for this:
 
-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for `NoTest()` and
+     `WithBackend(AutoForwardDiff())`, respectively.
 
-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+4. _How to specify the tolerances._ (Only if testing is enabled.)
 
-4. _How to specify the tolerances._ (Only if `test=true`.)
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   atol or rtol is satisfied. The default values are `100*eps()` for `atol` and
+   `sqrt(eps())` for `rtol`.
 
-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   For the most part, it is the `rtol` check that is more meaningful, because
+   we cannot know the magnitude of logp and its gradient a priori. The `atol`
+   value is supplied to handle the case where gradients are equal to zero.
 
 5. _Whether to output extra logging information._
@@ -180,48 +218,58 @@ thrown as-is.
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
-    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    atol::AbstractFloat=100 * eps(),
+    rtol::AbstractFloat=sqrt(eps()),
+    rng::AbstractRNG=default_rng(),
+    varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params) # Concretise
 
+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println("       params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
     grad = collect(grad)
     verbose && println("       actual : $((value, grad))")
 
-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
         end
+        # Perform testing
         verbose && println("     expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
     end
 
+    # Benchmark
     time_vs_primal = if benchmark
        primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
        grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
@@ -237,8 +285,8 @@ function run_ad(
         varinfo,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
```
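The either-`atol`-or-`rtol` acceptance rule described in the docstring above follows Julia's `isapprox`, which a few standalone checks illustrate (the tolerance values mirror the new defaults; this snippet is illustrative and independent of DynamicPPL):

```julia
atol, rtol = 100 * eps(), sqrt(eps())

# A pair of values passes if EITHER tolerance is satisfied:
isapprox(1.0, 1.0 + 1e-9; atol=atol, rtol=rtol)  # true: within rtol
isapprox(0.0, 1e-15; atol=atol, rtol=rtol)       # true: atol rescues zero gradients
isapprox(0.0, 1e-3; atol=atol, rtol=rtol)        # false: fails both tolerances
```

The second line shows why `atol` is needed at all: relative error against an exactly-zero reference is never satisfied, so a purely relative check would reject correct zero gradients.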

test/ad.jl

Lines changed: 9 additions & 7 deletions
```diff
@@ -1,4 +1,5 @@
 using DynamicPPL: LogDensityFunction
+using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
 
 @testset "Automatic differentiation" begin
     # Used as the ground truth that others are compared against.
@@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction
         linked_varinfo = DynamicPPL.link(varinfo, m)
         f = LogDensityFunction(m, linked_varinfo)
         x = DynamicPPL.getparams(f)
+
         # Calculate reference logp + gradient of logp using ForwardDiff
-        ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
-        ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
+        ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
+        ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
 
         @testset "$adtype" for adtype in test_adtypes
             @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
@@ -50,24 +52,24 @@ using DynamicPPL: LogDensityFunction
             if is_mooncake && is_1_11 && is_svi_vnv
                 # https://github.com/compintell/Mooncake.jl/issues/470
                 @test_throws ArgumentError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             elseif is_mooncake && is_1_10 && is_svi_vnv
                 # TODO: report upstream
                 @test_throws UndefRefError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             elseif is_mooncake && is_1_10 && is_svi_od
                 # TODO: report upstream
                 @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction(
-                    ref_ldf, adtype
+                    m, linked_varinfo; adtype=adtype
                 )
             else
-                @test DynamicPPL.TestUtils.AD.run_ad(
+                @test run_ad(
                     m,
                     adtype;
                     varinfo=linked_varinfo,
-                    expected_value_and_grad=(ref_logp, ref_grad),
+                    test=WithExpectedResult(ref_logp, ref_grad),
                 ) isa Any
             end
         end
```
