Skip to content

Commit e1043ae

Browse files
committed
Don't need to do elementwise check
1 parent 3587ce5 commit e1043ae

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/test_utils/ad.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,12 @@ Everything else is optional, and can be categorised into several groups:
194194
Both absolute and relative tolerances can be specified using the `atol` and
195195
`rtol` keyword arguments respectively. The behaviour of these is similar to
196196
`isapprox()`, i.e. the value and gradient are considered correct if either
197-
atol or rtol is satisfied. The default values are `1e-8` for `atol` and
197+
atol or rtol is satisfied. The default values are `100*eps()` for `atol` and
198198
`sqrt(eps())` for `rtol`.
199199
200-
Note that gradients are always compared elementwise.
200+
For the most part, it is the `rtol` check that is more meaningful, because
201+
we cannot know the magnitude of logp and its gradient a priori. The `atol`
202+
value is supplied to handle the case where gradients are equal to zero.
201203
202204
5. _Whether to output extra logging information._
203205
@@ -218,7 +220,7 @@ function run_ad(
218220
adtype::AbstractADType;
219221
test::Union{AbstractADCorrectnessTestSetting,Bool}=WithBackend(),
220222
benchmark::Bool=false,
221-
atol::AbstractFloat=1e-8,
223+
atol::AbstractFloat=100 * eps(),
222224
rtol::AbstractFloat=sqrt(eps()),
223225
rng::AbstractRNG=default_rng(),
224226
varinfo::AbstractVarInfo=link(VarInfo(rng, model), model),
@@ -264,9 +266,7 @@ function run_ad(
264266
verbose && println(" expected : $((value_true, grad_true))")
265267
exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
266268
isapprox(value, value_true; atol=atol, rtol=rtol) || exc()
267-
for (g, g_true) in zip(grad, grad_true)
268-
isapprox(g, g_true; atol=atol, rtol=rtol) || exc()
269-
end
269+
isapprox(grad, grad_true; atol=atol, rtol=rtol) || exc()
270270
end
271271

272272
# Benchmark

0 commit comments

Comments
 (0)