Skip to content

Commit c5fd246

Browse files
authored
Extend Base.isapprox instead of defining cr_isapprox (#26)
Currently in the testing utilities we define an internal `cr_isapprox` function that mimics `isapprox` except we only accept a subset of the keyword arguments accepted by `isapprox`. Instead, we can just define methods for `isapprox` on ChainRules-defined types for the purposes of testing. That way we can more easily take arbitrary keyword arguments and get Test's special-cased pretty printing for `@test a ≈ b`.
1 parent 2c009d1 commit c5fd246

File tree

1 file changed

+22
-19
lines changed

1 file changed

+22
-19
lines changed

test/test_util.jl

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,64 @@ using FDM: jvp, j′vp
33
const _fdm = central_fdm(5, 1)
44

55
"""
6-
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
6+
frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
77
88
# Arguments
99
- `f`: Function for which the `frule` should be tested.
1010
- `x`: input at which to evaluate `f` (should generally be set randomly).
1111
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
12+
13+
All keyword arguments except for `fdm` are passed to `isapprox`.
1214
"""
13-
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm)
14-
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm)
15+
function frule_test(f, (x, ẋ); rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
16+
return frule_test(f, ((x, ẋ),); rtol=rtol, atol=atol, fdm=fdm, kwargs...)
1517
end
1618

17-
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
19+
function frule_test(f, xẋs::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
1820
xs, ẋs = collect(zip(xẋs...))
1921
Ω, dΩ_rule = ChainRules.frule(f, xs...)
2022
@test f(xs...) == Ω
2123

2224
dΩ_ad, dΩ_fd = dΩ_rule(ẋs...), jvp(fdm, xs->f(xs...), (xs, ẋs))
23-
@test cr_isapprox(dΩ_ad, dΩ_fd, rtol, atol)
25+
@test isapprox(dΩ_ad, dΩ_fd; rtol=rtol, atol=atol, kwargs...)
2426
end
2527

2628
"""
27-
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
29+
rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs...)
2830
2931
# Arguments
3032
- `f`: Function to which rule should be applied.
3133
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
3234
- `x`: input at which to evaluate `f` (should generally be set randomly).
3335
- `x̄`: currently accumulated adjoint (should generally be set randomly).
36+
37+
All keyword arguments except for `fdm` are passed to `isapprox`.
3438
"""
35-
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm)
39+
function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
3640
# Check correctness of evaluation.
3741
fx, dx = ChainRules.rrule(f, x)
3842
@test fx f(x)
3943

4044
# Correctness testing via finite differencing.
4145
x̄_ad, x̄_fd = dx(ȳ), j′vp(fdm, f, ȳ, x)
42-
@test cr_isapprox(x̄_ad, x̄_fd, rtol, atol)
46+
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
4347

4448
# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
4549
test_accumulation(x̄, dx, ȳ, x̄_ad)
4650
test_accumulation(Zero(), dx, ȳ, x̄_ad)
4751
end
4852

49-
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
53+
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm, kwargs...)
5054
# Check correctness of evaluation.
5155
xs, x̄s = collect(zip(xx̄s...))
5256
Ω, Δx_rules = ChainRules.rrule(f, xs...)
5357
@test f(xs...) == Ω
5458

5559
# Correctness testing via finite differencing.
5660
Δxs_ad, Δxs_fd = map(Δx_rule->Δx_rule(ȳ), Δx_rules), j′vp(fdm, f, ȳ, xs...)
57-
@test all(map((Δx_ad, Δx_fd)->cr_isapprox(Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))
61+
@test all(zip(Δxs_ad, Δxs_fd)) do (Δx_ad, Δx_fd)
62+
isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...)
63+
end
5864

5965
# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
6066
map(x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
@@ -64,20 +70,17 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
6470
end
6571
end
6672

67-
function cr_isapprox(d_ad, d_fd, rtol, atol)
68-
return isapprox(d_ad, d_fd; rtol=rtol, atol=atol)
69-
end
70-
function cr_isapprox(ad::Wirtinger, fd, rtol, atol)
73+
function Base.isapprox(ad::Wirtinger, fd; kwargs...)
7174
error("Finite differencing with Wirtinger rules not implemented")
7275
end
73-
function cr_isapprox(d_ad::Casted, d_fd, rtol, atol)
74-
return all(isapprox.(extern(d_ad), d_fd; rtol=rtol, atol=atol))
76+
function Base.isapprox(d_ad::Casted, d_fd; kwargs...)
77+
return all(isapprox.(extern(d_ad), d_fd; kwargs...))
7578
end
76-
function cr_isapprox(d_ad::DNE, d_fd, rtol, atol)
79+
function Base.isapprox(d_ad::DNE, d_fd; kwargs...)
7780
error("Tried to differentiate w.r.t. a DNE")
7881
end
79-
function cr_isapprox(d_ad::Thunk, d_fd, rtol, atol)
80-
return isapprox(extern(d_ad), d_fd; rtol=rtol, atol=atol)
82+
function Base.isapprox(d_ad::Thunk, d_fd; kwargs...)
83+
return isapprox(extern(d_ad), d_fd; kwargs...)
8184
end
8285

8386
function test_accumulation(x̄, dx, ȳ, partial)

0 commit comments

Comments
 (0)