
Commit c6c9595

Fix interface so that callers can inspect results

1 parent: f17300a

src/test_utils/ad.jl

Lines changed: 54 additions & 37 deletions
@@ -10,13 +10,7 @@ using Random: Random, Xoshiro
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad
-
-# This function needed to work around the fact that different backends can
-# return different AbstractArrays for the gradient. See
-# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
-# context.
-_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
+export ADResult, run_ad, ADIncorrectException
 
 """
     REFERENCE_ADTYPE
@@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl.
 const REFERENCE_ADTYPE = AutoForwardDiff()
 
 """
-    ADResult
+    ADIncorrectException{T<:Real}
+
+Exception thrown when an AD backend returns an incorrect value or gradient.
+
+The type parameter `T` is the numeric type of the value and gradient.
+"""
+struct ADIncorrectException{T<:Real} <: Exception
+    value_expected::T
+    value_actual::T
+    grad_expected::Vector{T}
+    grad_actual::Vector{T}
+end
+
+"""
+    ADResult{Tparams<:Real,Tresult<:Real}
 
 Data structure to store the results of the AD correctness test.
+
+The type parameter `Tparams` is the numeric type of the parameters passed in;
+`Tresult` is the type of the value and the gradient.
 """
-struct ADResult
+struct ADResult{Tparams<:Real,Tresult<:Real}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
     "The values at which the model was evaluated"
-    params::Vector{<:Real}
+    params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
     "The absolute tolerance for the value of logp"
-    value_atol::Real
+    value_atol::Tresult
     "The absolute tolerance for the gradient of logp"
-    grad_atol::Real
+    grad_atol::Tresult
     "The expected value of logp"
-    value_expected::Union{Nothing,Float64}
+    value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
-    grad_expected::Union{Nothing,Vector{Float64}}
+    grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Real}
+    value_actual::Union{Nothing,Tresult}
     "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Float64}}
+    grad_actual::Union{Nothing,Vector{Tresult}}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Float64}
+    time_vs_primal::Union{Nothing,Tresult}
 end
 
 """
@@ -64,27 +75,27 @@ end
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
-        linked::Bool=true,
-        varinfo::AbstractVarInfo=VarInfo(model),
+        varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:Real}}=nothing,
        reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
        verbose=true,
    )::ADResult
 
+### Description
+
 Test the correctness and/or benchmark the AD backend `adtype` for the model
 `model`.
 
 Whether to test and benchmark is controlled by the `test` and `benchmark`
 keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
 
-Returns an [`ADResult`](@ref) object, which contains the results of the
-test and/or benchmark.
-
 Note that to run AD successfully you will need to import the AD backend itself.
 For example, to test with `AutoReverseDiff()` you will need to run `import
 ReverseDiff`.
 
+### Arguments
+
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
@@ -146,35 +157,40 @@ Everything else is optional, and can be categorised into several groups:
 
 By default, this function prints messages when it runs. To silence it, set
 `verbose=false`.
+
+### Returns / Throws
+
+Returns an [`ADResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+If `test` is `true` and the AD backend returns an incorrect value or gradient, an
+`ADIncorrectException` is thrown. If a different error occurs, it will be
+thrown as-is.
 """
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test=true,
-    benchmark=false,
-    value_atol=1e-6,
-    grad_atol=1e-6,
-    linked::Bool=true,
-    varinfo::AbstractVarInfo=VarInfo(model),
+    test::Bool=true,
+    benchmark::Bool=false,
+    value_atol::Real=1e-6,
+    grad_atol::Real=1e-6,
+    varinfo::AbstractVarInfo=link(VarInfo(model), model),
     params::Union{Nothing,Vector{<:Real}}=nothing,
     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
     expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
     verbose=true,
 )::ADResult
-    if linked
-        varinfo = link(varinfo, model)
-    end
     if isnothing(params)
         params = varinfo[:]
     end
-    params = map(identity, params)
+    params = map(identity, params) # Concretise
 
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
     verbose && println(" params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
     value, grad = logdensity_and_gradient(ldf, params)
-    grad = _to_vec_f64(grad)
+    grad = collect(grad)
     verbose && println(" actual : $((value, grad))")
 
     if test
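
One consequence of the new signature: the removed `linked` keyword is subsumed by the new default `varinfo=link(VarInfo(model), model)`, so callers choose the parameter space by choosing the `varinfo` they pass in. A minimal sketch of the two calling patterns, assuming the utilities live under `DynamicPPL.TestUtils.AD` (the module path and the throwaway model are illustrative assumptions, not part of this commit):

    using DynamicPPL, Distributions
    using ADTypes: AutoForwardDiff
    import ForwardDiff                        # the backend being tested must be imported
    using DynamicPPL.TestUtils.AD: run_ad     # module path assumed

    @model demo() = x ~ LogNormal()           # hypothetical model with a constrained variable
    m = demo()

    # Default: the VarInfo is already linked, so the gradient is taken in
    # unconstrained (transformed) space.
    run_ad(m, AutoForwardDiff())

    # Passing an unlinked VarInfo explicitly recovers what `linked=false` used to
    # do, i.e. evaluation in the original constrained space.
    run_ad(m, AutoForwardDiff(); varinfo=VarInfo(m))
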
@@ -186,10 +202,11 @@ function run_ad(
             expected_value_and_grad
         end
         verbose && println(" expected : $((value_true, grad_true))")
-        grad_true = _to_vec_f64(grad_true)
-        # Then compare
-        @test isapprox(value, value_true; atol=value_atol)
-        @test isapprox(grad, grad_true; atol=grad_atol)
+        grad_true = collect(grad_true)
+
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=value_atol) || exc()
+        isapprox(grad, grad_true; atol=grad_atol) || exc()
     else
         value_true = nothing
         grad_true = nothing
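
With the `@test` calls replaced by an explicit throw, a caller can now trap an incorrect result and inspect both sides of the failed comparison, or read the fields of the returned `ADResult` on success. A sketch of the intended usage after this commit, again assuming the `DynamicPPL.TestUtils.AD` module path and using a made-up model and backend:

    using DynamicPPL, Distributions
    using ADTypes: AutoReverseDiff
    import ReverseDiff                        # the backend being tested must be imported
    using DynamicPPL.TestUtils.AD: run_ad, ADIncorrectException   # module path assumed

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1.0)
    end

    try
        result = run_ad(demo(), AutoReverseDiff(); benchmark=true)
        # On success, the ADResult fields are available for inspection:
        @show result.value_actual result.grad_actual result.time_vs_primal
    catch e
        e isa ADIncorrectException || rethrow()
        # On a failed comparison, the exception carries both sides of the check:
        @show e.value_expected e.value_actual
        @show e.grad_expected e.grad_actual
    end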
