
Commit 4ce84c2

Rework API for AD testing
1 parent 1882f72 commit 4ce84c2

4 files changed: +108 −46 lines

HISTORY.md

Lines changed: 6 additions & 0 deletions

@@ -8,6 +8,12 @@
 
 The `@submodel` macro is fully removed; please use `to_submodel` instead.
 
+### `DynamicPPL.TestUtils.AD.run_ad`
+
+The three keyword arguments `test`, `reference_adtype`, and `expected_value_and_grad` have been merged into a single `test` keyword argument.
+Please see the API documentation for more details.
+(The old `test=true` and `test=false` values are still valid; you only need to adjust the invocation if you were explicitly passing the `reference_adtype` or `expected_value_and_grad` arguments.)
+
 ### Accumulators
 
 This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
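To illustrate the migration described above: a minimal sketch, assuming a hypothetical model `demo_model` and previously computed `logp`/`grad` values (none of these names come from the commit itself):

```julia
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, WithExpectedResult
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Before: the ground truth was specified via separate keyword arguments, e.g.
#   run_ad(demo_model, AutoReverseDiff(); expected_value_and_grad=(logp, grad))
#   run_ad(demo_model, AutoReverseDiff(); reference_adtype=AutoForwardDiff())

# After: both use cases go through the single `test` keyword argument.
run_ad(demo_model, AutoReverseDiff(); test=WithExpectedResult(logp, grad))
run_ad(demo_model, AutoReverseDiff(); test=WithBackend(AutoForwardDiff()))
```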

docs/src/api.md

Lines changed: 15 additions & 0 deletions

@@ -211,6 +211,21 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
 
 ```@docs
 DynamicPPL.TestUtils.AD.run_ad
+```
+
+The default test setting is to compare against ForwardDiff.
+You can have more fine-grained control over how to test the AD backend using the following types:
+
+```@docs
+DynamicPPL.TestUtils.AD.AbstractADCorrectnessTestSetting
+DynamicPPL.TestUtils.AD.WithBackend
+DynamicPPL.TestUtils.AD.WithExpectedResult
+DynamicPPL.TestUtils.AD.NoTest
+```
+
+These are returned / thrown by the `run_ad` function:
+
+```@docs
 DynamicPPL.TestUtils.AD.ADResult
 DynamicPPL.TestUtils.AD.ADIncorrectException
 ```
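As a rough usage sketch of the settings documented above (`demo_model` is a placeholder for any model defined with `@model`):

```julia
using DynamicPPL.TestUtils.AD: run_ad, WithBackend, NoTest
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Check ReverseDiff against a ForwardDiff reference (this is also the default).
run_ad(demo_model, AutoReverseDiff(); test=WithBackend(AutoForwardDiff()))

# Skip the correctness check entirely, e.g. when only benchmarking.
run_ad(demo_model, AutoReverseDiff(); test=NoTest(), benchmark=true)
```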

src/test_utils/ad.jl

Lines changed: 81 additions & 42 deletions

@@ -4,14 +4,7 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL:
-    Model,
-    LogDensityFunction,
-    VarInfo,
-    AbstractVarInfo,
-    link,
-    DefaultContext,
-    AbstractContext
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: Random, Xoshiro
 using Statistics: median
@@ -20,12 +13,48 @@ using Test: @test
 export ADResult, run_ad, ADIncorrectException
 
 """
-    REFERENCE_ADTYPE
+    AbstractADCorrectnessTestSetting
 
-Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
-it's the default AD backend used in Turing.jl.
+Different ways of testing the correctness of an AD backend.
 """
-const REFERENCE_ADTYPE = AutoForwardDiff()
+abstract type AbstractADCorrectnessTestSetting end
+
+"""
+    WithBackend(adtype::AbstractADType=AutoForwardDiff()) <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against the result obtained with `adtype`.
+
+`adtype` defaults to ForwardDiff.jl, since it's the default AD backend used in
+Turing.jl.
+"""
+struct WithBackend{AD<:AbstractADType} <: AbstractADCorrectnessTestSetting
+    adtype::AD
+end
+WithBackend() = WithBackend(AutoForwardDiff())
+
+"""
+    WithExpectedResult(
+        value::T,
+        grad::AbstractVector{T}
+    ) where {T <: AbstractFloat}
+    <: AbstractADCorrectnessTestSetting
+
+Test correctness by comparing it against a known result (e.g. one obtained
+analytically, or one obtained with a different backend previously). Both the
+value of the primal (i.e. the log-density) as well as its gradient must be
+supplied.
+"""
+struct WithExpectedResult{T<:AbstractFloat} <: AbstractADCorrectnessTestSetting
+    value::T
+    grad::AbstractVector{T}
+end
+
+"""
+    NoTest() <: AbstractADCorrectnessTestSetting
+
+Disable correctness testing.
+"""
+struct NoTest <: AbstractADCorrectnessTestSetting end
 
 """
     ADIncorrectException{T<:AbstractFloat}
@@ -84,14 +113,12 @@ end
     run_ad(
         model::Model,
         adtype::ADTypes.AbstractADType;
-        test=true,
+        test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
@@ -143,22 +170,25 @@ Everything else is optional, and can be categorised into several groups:
    prep_params)`. You could then evaluate the gradient at a different set of
    parameters using the `params` keyword argument.
 
-3. _How to specify the results to compare against._ (Only if `test=true`.)
+3. _How to specify the results to compare against._
 
-   Once logp and its gradient have been calculated with the specified `adtype`,
-   they must be tested for correctness.
+   Once logp and its gradient have been calculated with the specified `adtype`,
+   they can optionally be tested for correctness. The exact way this is tested
+   is specified in the `test` parameter.
 
-   This can be done either by specifying `reference_adtype`, in which case logp
-   and its gradient will also be calculated with this reference in order to
-   obtain the ground truth; or by using `expected_value_and_grad`, which is a
-   tuple of `(logp, gradient)` that the calculated values must match. The
-   latter is useful if you are testing multiple AD backends and want to avoid
-   recalculating the ground truth multiple times.
+   There are several options for this:
 
-   The default reference backend is ForwardDiff. If none of these parameters are
-   specified, ForwardDiff will be used to calculate the ground truth.
+   - You can explicitly specify the correct value using
+     [`WithExpectedResult()`](@ref).
+   - You can compare against the result obtained with a different AD backend
+     using [`WithBackend(adtype)`](@ref).
+   - You can disable testing by passing [`NoTest()`](@ref).
+   - The default is to compare against the result obtained with ForwardDiff,
+     i.e. `WithBackend(AutoForwardDiff())`.
+   - `test=false` and `test=true` are synonyms for
+     `NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.
 
-4. _How to specify the tolerances._ (Only if `test=true`.)
+4. _How to specify the tolerances._ (Only if testing is enabled.)
 
    The tolerances for the value and gradient can be set using `value_atol` and
    `grad_atol`. These default to 1e-6.
@@ -180,48 +210,57 @@ thrown as-is.
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test::Bool=true,
+    test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
     value_atol::AbstractFloat=1e-6,
     grad_atol::AbstractFloat=1e-6,
     varinfo::AbstractVarInfo=link(VarInfo(model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
-    reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    # Convert Boolean `test` to an AbstractADCorrectnessTestSetting
+    if test isa Bool
+        test = test ? WithBackend() : NoTest()
+    end
+
+    # Extract parameters
     if isnothing(params)
         params = varinfo[:]
     end
     params = map(identity, params) # Concretise
 
+    # Calculate log-density and gradient with the backend of interest
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println(" params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
-
     value, grad = logdensity_and_gradient(ldf, params)
+    # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
    grad = collect(grad)
     verbose && println(" actual : $((value, grad))")
 
-    if test
-        # Calculate ground truth to compare against
-        value_true, grad_true = if expected_value_and_grad === nothing
-            ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
-            logdensity_and_gradient(ldf_reference, params)
-        else
-            expected_value_and_grad
+    # Test correctness
+    if test isa NoTest
+        value_true = nothing
+        grad_true = nothing
+    else
+        # Get the correct result
+        if test isa WithExpectedResult
+            value_true = test.value
+            grad_true = test.grad
+        elseif test isa WithBackend
+            ldf_reference = LogDensityFunction(model, varinfo; adtype=test.adtype)
+            value_true, grad_true = logdensity_and_gradient(ldf_reference, params)
+            # collect(): https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754
+            grad_true = collect(grad_true)
        end
+        # Perform testing
         verbose && println(" expected : $((value_true, grad_true))")
-        grad_true = collect(grad_true)
-
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
         isapprox(value, value_true; atol=value_atol) || exc()
         isapprox(grad, grad_true; atol=grad_atol) || exc()
-    else
-        value_true = nothing
-        grad_true = nothing
     end
 
+    # Benchmark
     time_vs_primal = if benchmark
         primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
         grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
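The Boolean-to-setting conversion at the top of `run_ad` keeps existing call sites working; a sketch of what the synonyms resolve to (`demo_model` is a hypothetical placeholder):

```julia
using DynamicPPL.TestUtils.AD: run_ad
using ADTypes: AutoReverseDiff

# `test=true` is rewritten internally to `WithBackend()`, i.e. a comparison
# against ForwardDiff, while `test=false` becomes `NoTest()`.
run_ad(demo_model, AutoReverseDiff(); test=true)
run_ad(demo_model, AutoReverseDiff(); test=false)
```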

test/ad.jl

Lines changed: 6 additions & 4 deletions

@@ -1,4 +1,5 @@
 using DynamicPPL: LogDensityFunction
+using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
 
 @testset "Automatic differentiation" begin
     # Used as the ground truth that others are compared against.
@@ -31,9 +32,10 @@ using DynamicPPL: LogDensityFunction
         linked_varinfo = DynamicPPL.link(varinfo, m)
         f = LogDensityFunction(m, linked_varinfo)
         x = DynamicPPL.getparams(f)
+
         # Calculate reference logp + gradient of logp using ForwardDiff
-        ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
-        ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
+        ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
+        ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
 
         @testset "$adtype" for adtype in test_adtypes
             @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
@@ -63,11 +65,11 @@ using DynamicPPL: LogDensityFunction
                     ref_ldf, adtype
                 )
             else
-                @test DynamicPPL.TestUtils.AD.run_ad(
+                @test run_ad(
                     m,
                     adtype;
                     varinfo=linked_varinfo,
-                    expected_value_and_grad=(ref_logp, ref_grad),
+                    test=WithExpectedResult(ref_logp, ref_grad),
                 ) isa Any
             end
         end
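The updated test shows the intended pattern for avoiding repeated ground-truth computation: run the reference backend once with `NoTest()`, then reuse its `value_actual` and `grad_actual` fields. Distilled into a sketch (with a hypothetical `demo_model`):

```julia
using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest
using ADTypes: AutoForwardDiff, AutoReverseDiff

# Compute the ForwardDiff reference once, without testing it against anything.
ref = run_ad(demo_model, AutoForwardDiff(); test=NoTest())

# Reuse the stored value and gradient when testing another backend.
run_ad(demo_model, AutoReverseDiff(); test=WithExpectedResult(ref.value_actual, ref.grad_actual))
```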
