HISTORY.md (7 changes: 6 additions & 1 deletion)
@@ -25,9 +25,14 @@ Please see the API documentation for more details.

There is now also an `rng` keyword argument to help seed parameter generation.

Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
Instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
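For illustration, here is a minimal sketch of that comparison rule (not the package's exact implementation; the helper name is made up):

```julia
# `isapprox`-style check: two values count as equal if they agree to within
# `atol`, or to within `rtol` relative to the larger magnitude.
approx_equal(a, b; atol=1e-8, rtol=sqrt(eps())) =
    abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
```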

Finally, the `ADResult` object returned by `run_ad` now has both `grad_time` and `primal_time` fields, which contain (respectively) the time it took to calculate the gradient of logp, and the time taken to calculate logp itself.
Times are reported in seconds.
Previously there was only a single `time_vs_primal` field which represented the ratio of these two.
You can of course access the same quantity by dividing `grad_time` by `primal_time`.
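For example, if `result` is an `ADResult` obtained from a benchmarked `run_ad` call (a usage sketch; the variable name is illustrative):

```julia
# `result` comes from e.g. `run_ad(model, adtype; benchmark=true)`.
slowdown = result.grad_time / result.primal_time  # the old `time_vs_primal`
```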

### `DynamicPPL.TestUtils.check_model`

You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`.
src/test_utils/ad.jl (47 changes: 34 additions & 13 deletions)
@@ -109,8 +109,11 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloa
value_actual::Tresult
"The gradient of logp (calculated using `adtype`)"
grad_actual::Vector{Tresult}
"If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
time_vs_primal::Union{Nothing,Tresult}
"If benchmarking was requested, the time taken by the AD backend to evaluate the gradient
of logp"
grad_time::Union{Nothing,Tresult}
"If benchmarking was requested, the time taken by the AD backend to evaluate logp"
primal_time::Union{Nothing,Tresult}
end

"""
@@ -121,6 +124,8 @@ end
benchmark=false,
atol::AbstractFloat=1e-8,
rtol::AbstractFloat=sqrt(eps()),
getlogdensity::Function=getlogjoint_internal,
rng::Random.AbstractRNG=Random.default_rng(),
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
verbose=true,
@@ -143,7 +148,7 @@ ReverseDiff`.
There are two positional arguments, which absolutely must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.
1. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

@@ -156,7 +161,7 @@ Everything else is optional, and can be categorised into several groups:
means that the parameters in the VarInfo have been transformed to
unconstrained Euclidean space if they aren't already in that space.

2. _How to specify the parameters._
1. _How to specify the parameters._

For maximum control over this, generate a vector of parameters yourself and
pass this as the `params` argument. If you don't specify this, it will be
@@ -174,7 +179,18 @@ Everything else is optional, and can be categorised into several groups:
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

3. _How to specify the results to compare against._
1. _Which type of logp is being calculated._
By default, `run_ad` evaluates the 'internal log joint density' of the model,
i.e., the log joint density in the unconstrained space. Thus, for example, in
`@model f() = x ~ LogNormal()`
the internal log joint density is `logpdf(Normal(), log(x))`. This is the
relevant log density for e.g. Hamiltonian Monte Carlo samplers and is therefore
the most useful to test.
If you want the log joint density in the original model parameterisation, you
can use `getlogjoint`. Likewise, if you want only the prior or likelihood,
you can use `getlogprior` or `getloglikelihood`, respectively.

1. _How to specify the results to compare against._

Once logp and its gradient have been calculated with the specified `adtype`,
they can optionally be tested for correctness. The exact way this is tested
@@ -192,7 +208,7 @@ Everything else is optional, and can be categorised into several groups:
- `test=false` and `test=true` are synonyms for
`NoTest()` and `WithBackend(AutoForwardDiff())`, respectively.

4. _How to specify the tolerances._ (Only if testing is enabled.)
1. _How to specify the tolerances._ (Only if testing is enabled.)

Both absolute and relative tolerances can be specified using the `atol` and
`rtol` keyword arguments respectively. The behaviour of these is similar to
@@ -204,7 +220,7 @@ we cannot know the magnitude of logp and its gradient a priori. The `atol`
we cannot know the magnitude of logp and its gradient a priori. The `atol`
value is supplied to handle the case where gradients are equal to zero.

5. _Whether to output extra logging information._
1. _Whether to output extra logging information._

By default, this function prints messages when it runs. To silence it, set
`verbose=false`.
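Putting the options above together, here is a hedged usage sketch based on this docstring; the module path `DynamicPPL.TestUtils.AD`, the choice of backends, and the seed are assumptions and may differ from the actual package:

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff, AutoReverseDiff
using ForwardDiff, ReverseDiff   # load the backends themselves
using Random: Xoshiro

# A toy model whose link transform (x -> log x) is non-trivial.
@model function demo()
    x ~ LogNormal()
end

# Assumed module path; adjust to wherever `run_ad` lives in your version.
result = DynamicPPL.TestUtils.AD.run_ad(
    demo(),
    AutoReverseDiff();
    test=true,          # check against a reference backend (ForwardDiff by default)
    benchmark=true,     # populate `grad_time` and `primal_time`
    rng=Xoshiro(23),    # seed parameter generation
    atol=1e-8,
    rtol=sqrt(eps()),
    verbose=true,
)

result.grad_time / result.primal_time   # gradient slowdown relative to the primal
```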
@@ -277,14 +293,18 @@ function run_ad(
end

# Benchmark
time_vs_primal = if benchmark
grad_time, primal_time = if benchmark
primal_benchmark = @be (ldf, params) logdensity(_[1], _[2])
grad_benchmark = @be (ldf, params) logdensity_and_gradient(_[1], _[2])
t = median(grad_benchmark).time / median(primal_benchmark).time
verbose && println("grad / primal : $(t)")
t
median_primal = median(primal_benchmark).time
median_grad = median(grad_benchmark).time
r(f) = round(f; sigdigits=4)
verbose && println(
"grad / primal : $(r(median_grad))/$(r(median_primal)) = $(r(median_grad / median_primal))",
)
(median_grad, median_primal)
else
nothing
nothing, nothing
end

return ADResult(
Expand All @@ -299,7 +319,8 @@ function run_ad(
grad_true,
value,
grad,
time_vs_primal,
grad_time,
primal_time,
)
end
