Commit 9eef762

Use atol and rtol
1 parent fc2df33 commit 9eef762

2 files changed: +32 -18 lines changed

HISTORY.md

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,12 @@ The three keyword arguments, `test`, `reference_backend`, and `expected_value_an
 Please see the API documentation for more details.
 (The old `test=true` and `test=false` values are still valid, and you only need to adjust the invocation if you were explicitly passing the `reference_backend` or `expected_value_and_grad` arguments.)

+There is now also an `rng` keyword argument to help seed parameter generation.
+
+Finally, instead of specifying `value_atol` and `grad_atol`, you can now specify `atol` and `rtol` which are used for both value and gradient.
+Their semantics are the same as in Julia's `isapprox`; two values are equal if they satisfy either `atol` or `rtol`.
+Note that gradients are always compared elementwise (instead of using the norm, which is what `isapprox` does).
+
 ### Accumulators

 This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
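
The "either `atol` or `rtol`" rule added to HISTORY.md in the hunk above can be illustrated with plain `isapprox` on scalars. This is a minimal sketch and not part of the commit: the sample numbers are made up, and the tolerances are simply the new `run_ad` defaults from this commit. For real numbers, `isapprox(x, y; atol, rtol)` checks `abs(x - y) <= max(atol, rtol * max(abs(x), abs(y)))`.

```julia
# Tolerances copied from the new `run_ad` defaults in this commit.
atol = 1e-8
rtol = sqrt(eps())   # about 1.5e-8 for Float64

# Large magnitudes: the absolute difference (1e-3) exceeds atol,
# but the relative bound rtol * 1e6 (about 1.5e-2) is satisfied.
large_ok = isapprox(1.0e6, 1.0e6 + 1.0e-3; atol=atol, rtol=rtol)

# Values near zero: the relative bound is tiny, but the absolute
# difference (5e-9) is within atol.
small_ok = isapprox(0.0, 5.0e-9; atol=atol, rtol=rtol)

# Neither bound holds: the difference (1e-6) exceeds both atol and
# rtol * max(|x|, |y|) (about 1.5e-8).
neither_ok = isapprox(1.0, 1.0 + 1.0e-6; atol=atol, rtol=rtol)

println((large_ok, small_ok, neither_ok))
```

The first two comparisons pass and the third fails, so a tighter `atol` default (`1e-8` instead of `1e-6`) does not by itself reject large log-probabilities, because `rtol` scales with their magnitude.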

src/test_utils/ad.jl

Lines changed: 26 additions & 18 deletions
@@ -10,7 +10,7 @@ using Random: AbstractRNG, default_rng
 using Statistics: median
 using Test: @test

-export ADResult, run_ad, ADIncorrectException
+export ADResult, run_ad, ADIncorrectException, WithBackend, WithExpectedResult, NoTest

 """
     AbstractADCorrectnessTestSetting
@@ -74,17 +74,18 @@ struct ADIncorrectException{T<:AbstractFloat} <: Exception
 end

 """
-    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}

 Data structure to store the results of the AD correctness test.

 The type parameter `Tparams` is the numeric type of the parameters passed in;
-`Tresult` is the type of the value and the gradient.
+`Tresult` is the type of the value and the gradient; and `Ttol` is the type of the
+absolute and relative tolerances used for correctness testing.

 # Fields
 $(TYPEDFIELDS)
 """
-struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat,Ttol<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
@@ -93,10 +94,10 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
-    "The absolute tolerance for the value of logp"
-    value_atol::Tresult
-    "The absolute tolerance for the gradient of logp"
-    grad_atol::Tresult
+    "Absolute tolerance used for correctness test"
+    atol::Ttol
+    "Relative tolerance used for correctness test"
+    rtol::Ttol
     "The expected value of logp"
     value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
@@ -115,8 +116,8 @@ end
         adtype::ADTypes.AbstractADType;
         test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
         benchmark=false,
-        value_atol=1e-6,
-        grad_atol=1e-6,
+        atol::AbstractFloat=1e-8,
+        rtol::AbstractFloat=sqrt(eps()),
         varinfo::AbstractVarInfo=link(VarInfo(model), model),
         params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
         verbose=true,
@@ -190,8 +191,13 @@ Everything else is optional, and can be categorised into several groups:

 4. _How to specify the tolerances._ (Only if testing is enabled.)

-   The tolerances for the value and gradient can be set using `value_atol` and
-   `grad_atol`. These default to 1e-6.
+   Both absolute and relative tolerances can be specified using the `atol` and
+   `rtol` keyword arguments respectively. The behaviour of these is similar to
+   `isapprox()`, i.e. the value and gradient are considered correct if either
+   atol or rtol is satisfied. The default values are `1e-8` for `atol` and
+   `sqrt(eps())` for `rtol`.
+
+   Note that gradients are always compared elementwise.

 5. _Whether to output extra logging information._

@@ -212,8 +218,8 @@ function run_ad(
     adtype::AbstractADType;
     test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
     benchmark::Bool=false,
-    value_atol::AbstractFloat=1e-6,
-    grad_atol::AbstractFloat=1e-6,
+    atol::AbstractFloat=1e-8,
+    rtol::AbstractFloat=sqrt(eps()),
     rng::AbstractRNG=default_rng(),
     varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
     params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
@@ -257,8 +263,10 @@ function run_ad(
         # Perform testing
         verbose && println(" expected : $((value_true, grad_true))")
         exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
-        isapprox(value, value_true; atol=value_atol) || exc()
-        isapprox(grad, grad_true; atol=grad_atol) || exc()
+        isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
+        for (g, g_true) in zip(grad, grad_true)
+            isapprox(g, g_true; atol=atol, rtol=rtol) || exc()
+        end
     end

     # Benchmark
@@ -277,8 +285,8 @@ function run_ad(
         varinfo,
         params,
         adtype,
-        value_atol,
-        grad_atol,
+        atol,
+        rtol,
         value_true,
         grad_true,
         value,
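
The switch from a single `isapprox(grad, grad_true; ...)` call to a per-element loop in `run_ad` changes which gradients are accepted. The sketch below is illustrative only and not part of the commit: the gradient vectors are hypothetical, and the tolerances are the new defaults. For arrays, `isapprox` compares `norm(x - y)` against `max(atol, rtol * max(norm(x), norm(y)))`, so a large component can mask an error in a small one.

```julia
using LinearAlgebra  # array `isapprox` uses `norm` by default

atol = 1e-8
rtol = sqrt(eps())

# Hypothetical gradients: one large component that matches and one
# small component that is off by a factor of two.
grad      = [1.0e6, 1.0e-3]
grad_true = [1.0e6, 2.0e-3]

# Norm-based check (the old behaviour, apart from rtol): the norm of
# the difference is 1e-3, while rtol * norm(grad_true) is about
# 1.5e-2, so the check passes.
norm_based = isapprox(grad, grad_true; atol=atol, rtol=rtol)

# Elementwise check, mirroring the loop added in this commit: the
# second component differs by 1e-3, far above both tolerances.
elementwise = all(
    isapprox(g, g_true; atol=atol, rtol=rtol) for
    (g, g_true) in zip(grad, grad_true)
)

println((norm_based, elementwise))
```

Here the norm-based check accepts the gradient and the elementwise check rejects it, which is the stricter behaviour the commit opts into.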
