Merged
Changes from 8 commits
12 changes: 12 additions & 0 deletions HISTORY.md
@@ -4,6 +4,18 @@

**Breaking changes**

### AD testing utilities

`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
To disable this, pass the `linked=false` keyword argument.
If the calculated value or gradient is incorrect, it also throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than a test failure.
This exception contains the actual and expected gradient so you can inspect it if needed; see the documentation for more information.
In practice, this means that if you call `run_ad` inside a test suite, you should write `@test run_ad(...) isa Any` rather than a bare `run_ad(...)`.
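As a sketch of the new behaviour (the `demo` model is illustrative; `run_ad`, `ADIncorrectException`, and its fields are taken from this PR):

```julia
using DynamicPPL, Distributions, ForwardDiff, Test
using DynamicPPL.TestUtils.AD: run_ad, ADIncorrectException
using ADTypes: AutoForwardDiff

@model function demo()
    x ~ Normal()
end

# In a test suite: run_ad returns an ADResult on success and throws
# ADIncorrectException on an incorrect value/gradient, so assert `isa Any`
# rather than relying on an internal @test.
@test run_ad(demo(), AutoForwardDiff()) isa Any

# To inspect a failure, catch the exception and compare gradients.
try
    run_ad(demo(), AutoForwardDiff())
catch e
    e isa ADIncorrectException && @show e.grad_expected e.grad_actual
end
```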

### SimpleVarInfo linking / invlinking

Linking a linked SimpleVarInfo, or invlinking an unlinked SimpleVarInfo, now displays a warning instead of an error.

### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.
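A minimal migration sketch (the model and the replacement values are illustrative):

```julia
using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
end

vi = VarInfo(demo())
new_values = [0.5]  # must have the same length as vi[:]

# Before this release: VarInfo(vi, new_values)  (now removed)
# After:
vi2 = DynamicPPL.unflatten(vi, new_values)
```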
1 change: 1 addition & 0 deletions docs/src/api.md
@@ -212,6 +212,7 @@ To test and/or benchmark the performance of an AD backend on a model, DynamicPPL
```@docs
DynamicPPL.TestUtils.AD.run_ad
DynamicPPL.TestUtils.AD.ADResult
DynamicPPL.TestUtils.AD.ADIncorrectException
```

## Demo models
101 changes: 66 additions & 35 deletions src/test_utils/ad.jl
@@ -4,19 +4,13 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
using Test: @test

export ADResult, run_ad

# This function needed to work around the fact that different backends can
# return different AbstractArrays for the gradient. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
# context.
_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
export ADResult, run_ad, ADIncorrectException

"""
REFERENCE_ADTYPE
@@ -27,33 +21,50 @@
const REFERENCE_ADTYPE = AutoForwardDiff()

"""
ADResult
ADIncorrectException{T<:Real}
Exception thrown when an AD backend returns an incorrect value or gradient.
The type parameter `T` is the numeric type of the value and gradient.
"""
struct ADIncorrectException{T<:Real} <: Exception
value_expected::T
value_actual::T
grad_expected::Vector{T}
grad_actual::Vector{T}
end

"""
ADResult{Tparams<:Real,Tresult<:Real}
Data structure to store the results of the AD correctness test.
The type parameter `Tparams` is the numeric type of the parameters passed in;
`Tresult` is the type of the value and the gradient.
"""
struct ADResult
struct ADResult{Tparams<:Real,Tresult<:Real}
"The DynamicPPL model that was tested"
model::Model
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The values at which the model was evaluated"
params::Vector{<:Real}
params::Vector{Tparams}
"The AD backend that was tested"
adtype::AbstractADType
"The absolute tolerance for the value of logp"
value_atol::Real
value_atol::Tresult
"The absolute tolerance for the gradient of logp"
grad_atol::Real
grad_atol::Tresult
"The expected value of logp"
value_expected::Union{Nothing,Float64}
value_expected::Union{Nothing,Tresult}
"The expected gradient of logp"
grad_expected::Union{Nothing,Vector{Float64}}
grad_expected::Union{Nothing,Vector{Tresult}}
"The value of logp (calculated using `adtype`)"
value_actual::Union{Nothing,Real}
value_actual::Union{Nothing,Tresult}
"The gradient of logp (calculated using `adtype`)"
grad_actual::Union{Nothing,Vector{Float64}}
grad_actual::Union{Nothing,Vector{Tresult}}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Float64}
time_vs_primal::Union{Nothing,Tresult}
end

"""
@@ -64,26 +75,27 @@
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:Real}}=nothing,
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
verbose=true,
)::ADResult
### Description
Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.
Whether to test and benchmark is controlled by the `test` and `benchmark`
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.
Note that to run AD successfully you will need to import the AD backend itself.
For example, to test with `AutoReverseDiff()` you will need to run `import
ReverseDiff`.
### Arguments
There are two positional arguments, which absolutely must be provided:
1. `model` - The model being tested.
@@ -98,6 +110,11 @@
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
using a `TypedVarInfo` generated from the model.
It will also perform _linking_, that is, the parameters in the VarInfo will
be transformed to unconstrained Euclidean space if they aren't already in
that space. Note that the act of linking may change the length of the
parameters. To disable linking, set `linked=false`.
[Review comment — Member] I think this paragraph should be removed from this version? And instead say something like "by default, we'll use linked (explaining what "link" means) varinfo..."

[Reply — Member, Author] Oh, this is my bad. I messed with the interface a few times and forgot to update this.

[Reply — Member, Author] Changed now!
2. _How to specify the parameters._
For maximum control over this, generate a vector of parameters yourself and
@@ -140,27 +157,40 @@
By default, this function prints messages when it runs. To silence it, set
`verbose=false`.
### Returns / Throws
Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.
If `test` is `true` and the AD backend returns an incorrect value or gradient, an
`ADIncorrectException` is thrown. If a different error occurs, it will be
thrown as-is.
"""
function run_ad(
model::Model,
adtype::AbstractADType;
test=true,
benchmark=false,
value_atol=1e-6,
grad_atol=1e-6,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
test::Bool=true,
benchmark::Bool=false,
value_atol::Real=1e-6,
grad_atol::Real=1e-6,
[Review comment — Member] Might as well be Float (or even Float64)? My opinion is not strong here, just a mention. Also, could value_atol and grad_atol have different types?

[Reply — penelopeysm, Member, Author, Apr 16, 2025] Actually, yes, I think some of these Reals should be AbstractFloats? I think f64 might be a bit too restrictive. I guess it's pretty much always going to be f64 in regular usage, but since this is an exported interface I figured it should be generic.
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:Real}}=nothing,
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
verbose=true,
)::ADResult
if isnothing(params)
params = varinfo[:]

end
params = map(identity, params) # Concretise


verbose && @info "Running AD on $(model.f) with $(adtype)\n"
params = map(identity, params)
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
grad = _to_vec_f64(grad)
grad = collect(grad)

verbose && println(" actual : $((value, grad))")

if test
@@ -172,10 +202,11 @@
expected_value_and_grad
end
verbose && println(" expected : $((value_true, grad_true))")
grad_true = _to_vec_f64(grad_true)
# Then compare
@test isapprox(value, value_true; atol=value_atol)
@test isapprox(grad, grad_true; atol=grad_atol)
grad_true = collect(grad_true)


exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
isapprox(value, value_true; atol=value_atol) || exc()
isapprox(grad, grad_true; atol=grad_atol) || exc()

else
value_true = nothing
grad_true = nothing
4 changes: 2 additions & 2 deletions src/transforming.jl
@@ -19,9 +19,9 @@ function tilde_assume(
lp = Bijectors.logpdf_with_trans(right, r, !isinverse)

if istrans(vi, vn)
@assert isinverse "Trying to link already transformed variables"
isinverse || @warn "Trying to link an already transformed variable ($vn)"
else
@assert !isinverse "Trying to invlink non-transformed variables"
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
end

# Only transform if `!isinverse` since `vi[vn, right]`
18 changes: 10 additions & 8 deletions test/ad.jl
@@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = LogDensityFunction(m, varinfo)
linked_varinfo = DynamicPPL.link(varinfo, m)
f = LogDensityFunction(m, linked_varinfo)
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in test_adtypes
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
@info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"

# Put predicates here to avoid long lines
is_mooncake = adtype isa AutoMooncake
is_1_10 = v"1.10" <= VERSION < v"1.11"
is_1_11 = v"1.11" <= VERSION < v"1.12"
is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
is_svi_vnv =
linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}

# Mooncake doesn't work with several combinations of SimpleVarInfo.
if is_mooncake && is_1_11 && is_svi_vnv
@@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction
ref_ldf, adtype
)
else
DynamicPPL.TestUtils.AD.run_ad(
@test DynamicPPL.TestUtils.AD.run_ad(
m,
adtype;
varinfo=varinfo,
varinfo=linked_varinfo,
expected_value_and_grad=(ref_logp, ref_grad),
)
) isa Any
end
end
end
6 changes: 0 additions & 6 deletions test/simple_varinfo.jl
@@ -111,12 +111,6 @@
# Should be approx. the same as the "lazy" transformation.
@test logjoint(model, vi_linked) ≈ lp_linked

# TODO: Should not `VarInfo` also error here? The current implementation
# only warns and acts as a no-op.
if vi isa SimpleVarInfo
@test_throws AssertionError link!!(vi_linked, model)
end

# `invlink!!`
vi_invlinked = invlink!!(deepcopy(vi_linked), model)
lp_invlinked = getlogp(vi_invlinked)