@@ -3,58 +3,64 @@ using FDM: jvp, j′vp
3
3
const _fdm = central_fdm (5 , 1 )
4
4
5
5
"""
6
- frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
6
+ frule_test(f, (x, ẋ)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs... )
7
7
8
8
# Arguments
9
9
- `f`: Function for which the `frule` should be tested.
10
10
- `x`: input at which to evaluate `f` (should generally be set randomly).
11
11
- `ẋ`: differential w.r.t. `x` (should generally be set randomly).
12
+
13
+ All keyword arguments except for `fdm` are passed to `isapprox`.
12
14
"""
13
- function frule_test (f, (x, ẋ); rtol= 1e-9 , atol= 1e-9 , fdm= _fdm)
14
- return frule_test (f, ((x, ẋ),); rtol= rtol, atol= atol, fdm= fdm)
15
+ function frule_test (f, (x, ẋ); rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs ... )
16
+ return frule_test (f, ((x, ẋ),); rtol= rtol, atol= atol, fdm= fdm, kwargs ... )
15
17
end
16
18
17
- function frule_test (f, xẋs:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm)
19
+ function frule_test (f, xẋs:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs ... )
18
20
xs, ẋs = collect (zip (xẋs... ))
19
21
Ω, dΩ_rule = ChainRules. frule (f, xs... )
20
22
@test f (xs... ) == Ω
21
23
22
24
dΩ_ad, dΩ_fd = dΩ_rule (ẋs... ), jvp (fdm, xs-> f (xs... ), (xs, ẋs))
23
- @test cr_isapprox (dΩ_ad, dΩ_fd, rtol, atol)
25
+ @test isapprox (dΩ_ad, dΩ_fd; rtol= rtol , atol= atol, kwargs ... )
24
26
end
25
27
26
28
"""
27
- rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1))
29
+ rrule_test(f, ȳ, (x, x̄)...; rtol=1e-9, atol=1e-9, fdm=central_fdm(5, 1), kwargs... )
28
30
29
31
# Arguments
30
32
- `f`: Function to which rule should be applied.
31
33
- `ȳ`: adjoint w.r.t. output of `f` (should generally be set randomly).
32
34
- `x`: input at which to evaluate `f` (should generally be set randomly).
33
35
- `x̄`: currently accumulated adjoint (should generally be set randomly).
36
+
37
+ All keyword arguments except for `fdm` are passed to `isapprox`.
34
38
"""
35
- function rrule_test (f, ȳ, (x, x̄):: Tuple{Any, Any} ; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm)
39
+ function rrule_test (f, ȳ, (x, x̄):: Tuple{Any, Any} ; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs ... )
36
40
# Check correctness of evaluation.
37
41
fx, dx = ChainRules. rrule (f, x)
38
42
@test fx ≈ f (x)
39
43
40
44
# Correctness testing via finite differencing.
41
45
x̄_ad, x̄_fd = dx (ȳ), j′vp (fdm, f, ȳ, x)
42
- @test cr_isapprox (x̄_ad, x̄_fd, rtol, atol)
46
+ @test isapprox (x̄_ad, x̄_fd; rtol= rtol , atol= atol, kwargs ... )
43
47
44
48
# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
45
49
test_accumulation (x̄, dx, ȳ, x̄_ad)
46
50
test_accumulation (Zero (), dx, ȳ, x̄_ad)
47
51
end
48
52
49
- function rrule_test (f, ȳ, xx̄s:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm)
53
+ function rrule_test (f, ȳ, xx̄s:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm, kwargs ... )
50
54
# Check correctness of evaluation.
51
55
xs, x̄s = collect (zip (xx̄s... ))
52
56
Ω, Δx_rules = ChainRules. rrule (f, xs... )
53
57
@test f (xs... ) == Ω
54
58
55
59
# Correctness testing via finite differencing.
56
60
Δxs_ad, Δxs_fd = map (Δx_rule-> Δx_rule (ȳ), Δx_rules), j′vp (fdm, f, ȳ, xs... )
57
- @test all (map ((Δx_ad, Δx_fd)-> cr_isapprox (Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))
61
+ @test all (zip (Δxs_ad, Δxs_fd)) do (Δx_ad, Δx_fd)
62
+ isapprox (Δx_ad, Δx_fd; rtol= rtol, atol= atol, kwargs... )
63
+ end
58
64
59
65
# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
60
66
map (x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
@@ -64,20 +70,17 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
64
70
end
65
71
end
66
72
67
- function cr_isapprox (d_ad, d_fd, rtol, atol)
68
- return isapprox (d_ad, d_fd; rtol= rtol, atol= atol)
69
- end
70
- function cr_isapprox (ad:: Wirtinger , fd, rtol, atol)
73
+ function Base. isapprox (ad:: Wirtinger , fd; kwargs... )
71
74
error (" Finite differencing with Wirtinger rules not implemented" )
72
75
end
73
- function cr_isapprox (d_ad:: Casted , d_fd, rtol, atol )
74
- return all (isapprox .(extern (d_ad), d_fd; rtol = rtol, atol = atol ))
76
+ function Base . isapprox (d_ad:: Casted , d_fd; kwargs ... )
77
+ return all (isapprox .(extern (d_ad), d_fd; kwargs ... ))
75
78
end
76
- function cr_isapprox (d_ad:: DNE , d_fd, rtol, atol )
79
+ function Base . isapprox (d_ad:: DNE , d_fd; kwargs ... )
77
80
error (" Tried to differentiate w.r.t. a DNE" )
78
81
end
79
- function cr_isapprox (d_ad:: Thunk , d_fd, rtol, atol )
80
- return isapprox (extern (d_ad), d_fd; rtol = rtol, atol = atol )
82
+ function Base . isapprox (d_ad:: Thunk , d_fd; kwargs ... )
83
+ return isapprox (extern (d_ad), d_fd; kwargs ... )
81
84
end
82
85
83
86
function test_accumulation (x̄, dx, ȳ, partial)
0 commit comments