-
Notifications
You must be signed in to change notification settings - Fork 36
Link varinfo by default in AD testing utilities; make test suite run on linked varinfos #890
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
c5574ae
6d163fd
f17300a
c6c9595
41c32a0
3e6b2db
51a26b2
9bf0593
976885c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,19 +4,13 @@ | |
using Chairmarks: @be | ||
import DifferentiationInterface as DI | ||
using DocStringExtensions | ||
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo | ||
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link | ||
using LogDensityProblems: logdensity, logdensity_and_gradient | ||
using Random: Random, Xoshiro | ||
using Statistics: median | ||
using Test: @test | ||
|
||
export ADResult, run_ad | ||
|
||
# This function needed to work around the fact that different backends can | ||
# return different AbstractArrays for the gradient. See | ||
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more | ||
# context. | ||
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x) | ||
export ADResult, run_ad, ADIncorrectException | ||
|
||
""" | ||
REFERENCE_ADTYPE | ||
|
@@ -27,33 +21,50 @@ | |
const REFERENCE_ADTYPE = AutoForwardDiff() | ||
|
||
""" | ||
ADResult | ||
ADIncorrectException{T<:Real} | ||
Exception thrown when an AD backend returns an incorrect value or gradient. | ||
The type parameter `T` is the numeric type of the value and gradient. | ||
""" | ||
struct ADIncorrectException{T<:Real} <: Exception | ||
value_expected::T | ||
value_actual::T | ||
grad_expected::Vector{T} | ||
grad_actual::Vector{T} | ||
end | ||
|
||
""" | ||
ADResult{Tparams<:Real,Tresult<:Real} | ||
Data structure to store the results of the AD correctness test. | ||
The type parameter `Tparams` is the numeric type of the parameters passed in; | ||
`Tresult` is the type of the value and the gradient. | ||
""" | ||
struct ADResult | ||
struct ADResult{Tparams<:Real,Tresult<:Real} | ||
"The DynamicPPL model that was tested" | ||
model::Model | ||
"The VarInfo that was used" | ||
varinfo::AbstractVarInfo | ||
"The values at which the model was evaluated" | ||
params::Vector{<:Real} | ||
params::Vector{Tparams} | ||
"The AD backend that was tested" | ||
adtype::AbstractADType | ||
"The absolute tolerance for the value of logp" | ||
value_atol::Real | ||
value_atol::Tresult | ||
"The absolute tolerance for the gradient of logp" | ||
grad_atol::Real | ||
grad_atol::Tresult | ||
"The expected value of logp" | ||
value_expected::Union{Nothing,Float64} | ||
value_expected::Union{Nothing,Tresult} | ||
"The expected gradient of logp" | ||
grad_expected::Union{Nothing,Vector{Float64}} | ||
grad_expected::Union{Nothing,Vector{Tresult}} | ||
"The value of logp (calculated using `adtype`)" | ||
value_actual::Union{Nothing,Real} | ||
value_actual::Union{Nothing,Tresult} | ||
"The gradient of logp (calculated using `adtype`)" | ||
grad_actual::Union{Nothing,Vector{Float64}} | ||
grad_actual::Union{Nothing,Vector{Tresult}} | ||
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" | ||
time_vs_primal::Union{Nothing,Float64} | ||
time_vs_primal::Union{Nothing,Tresult} | ||
end | ||
|
||
""" | ||
|
@@ -64,26 +75,27 @@ | |
benchmark=false, | ||
value_atol=1e-6, | ||
grad_atol=1e-6, | ||
varinfo::AbstractVarInfo=VarInfo(model), | ||
params::Vector{<:Real}=varinfo[:], | ||
varinfo::AbstractVarInfo=link(VarInfo(model), model), | ||
params::Union{Nothing,Vector{<:Real}}=nothing, | ||
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, | ||
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, | ||
verbose=true, | ||
)::ADResult | ||
### Description | ||
Test the correctness and/or benchmark the AD backend `adtype` for the model | ||
`model`. | ||
Whether to test and benchmark is controlled by the `test` and `benchmark` | ||
keyword arguments. By default, `test` is `true` and `benchmark` is `false`. | ||
Returns an [`ADResult`](@ref) object, which contains the results of the | ||
test and/or benchmark. | ||
Note that to run AD successfully you will need to import the AD backend itself. | ||
For example, to test with `AutoReverseDiff()` you will need to run `import | ||
ReverseDiff`. | ||
### Arguments | ||
There are two positional arguments, which absolutely must be provided: | ||
1. `model` - The model being tested. | ||
|
@@ -98,6 +110,11 @@ | |
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to	 | 	|
using a linked `TypedVarInfo` generated from the model.	 | 	|
By default, the VarInfo is _linked_, that is, the parameters in the VarInfo	 | 	|
are transformed to unconstrained Euclidean space if they aren't already in	 | 	|
that space. Note that the act of linking may change the length of the	 | 	|
parameters. To test with an unlinked VarInfo instead, construct one yourself	 | 	|
(e.g. `VarInfo(model)`) and pass it as the `varinfo` argument.	 | 	|
2. _How to specify the parameters._ | ||
For maximum control over this, generate a vector of parameters yourself and | ||
|
@@ -140,27 +157,40 @@ | |
By default, this function prints messages when it runs. To silence it, set | ||
`verbose=false`. | ||
### Returns / Throws | ||
Returns an [`ADResult`](@ref) object, which contains the results of the | ||
test and/or benchmark. | ||
If `test` is `true` and the AD backend returns an incorrect value or gradient, an | ||
`ADIncorrectException` is thrown. If a different error occurs, it will be | ||
thrown as-is. | ||
""" | ||
function run_ad( | ||
model::Model, | ||
adtype::AbstractADType; | ||
test=true, | ||
benchmark=false, | ||
value_atol=1e-6, | ||
grad_atol=1e-6, | ||
varinfo::AbstractVarInfo=VarInfo(model), | ||
params::Vector{<:Real}=varinfo[:], | ||
test::Bool=true, | ||
benchmark::Bool=false, | ||
value_atol::Real=1e-6, | ||
grad_atol::Real=1e-6, | ||
|
||
varinfo::AbstractVarInfo=link(VarInfo(model), model), | ||
params::Union{Nothing,Vector{<:Real}}=nothing, | ||
reference_adtype::AbstractADType=REFERENCE_ADTYPE, | ||
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, | ||
verbose=true, | ||
)::ADResult | ||
if isnothing(params) | ||
params = varinfo[:] | ||
end | ||
params = map(identity, params) # Concretise | ||
|
||
verbose && @info "Running AD on $(model.f) with $(adtype)\n" | ||
params = map(identity, params) | ||
verbose && println(" params : $(params)") | ||
ldf = LogDensityFunction(model, varinfo; adtype=adtype) | ||
|
||
value, grad = logdensity_and_gradient(ldf, params) | ||
grad = _to_vec_f64(grad) | ||
grad = collect(grad) | ||
sunxd3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
verbose && println(" actual : $((value, grad))") | ||
|
||
if test | ||
|
@@ -172,10 +202,11 @@ | |
expected_value_and_grad | ||
end | ||
verbose && println(" expected : $((value_true, grad_true))") | ||
grad_true = _to_vec_f64(grad_true) | ||
# Then compare | ||
@test isapprox(value, value_true; atol=value_atol) | ||
@test isapprox(grad, grad_true; atol=grad_atol) | ||
grad_true = collect(grad_true) | ||
|
||
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) | ||
isapprox(value, value_true; atol=value_atol) || exc() | ||
isapprox(grad, grad_true; atol=grad_atol) || exc() | ||
else | ||
value_true = nothing | ||
grad_true = nothing | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this paragraph should be removed from this version?
And instead say something like "by default, we'll use linked (explaining what "link" means) varinfo..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, this is my bad. I messed with the interface a few times and forgot to update this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed now!