@@ -42,7 +42,8 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
42
42
@test cr_isapprox (x̄_ad, x̄_fd, rtol, atol)
43
43
44
44
# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
45
- test_adjoint! (x̄, dx, ȳ, x̄_ad)
45
+ test_accumulation (x̄, dx, ȳ, x̄_ad)
46
+ test_accumulation (Zero (), dx, ȳ, x̄_ad)
46
47
end
47
48
48
49
function rrule_test (f, ȳ, xx̄s:: Tuple{Any, Any} ...; rtol= 1e-9 , atol= 1e-9 , fdm= _fdm)
@@ -56,7 +57,11 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
56
57
@test all (map ((Δx_ad, Δx_fd)-> cr_isapprox (Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))
57
58
58
59
# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
59
- map ((x̄, Δx_rule, Δx_ad)-> test_adjoint! (x̄, Δx_rule, ȳ, Δx_ad), x̄s, Δx_rules, Δxs_ad)
60
+ map (x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
61
+ test_accumulation (x̄, Δx_rule, ȳ, Δx_ad)
62
+ test_accumulation (Zero (), Δx_rule, ȳ, Δx_ad)
63
+ return nothing
64
+ end
60
65
end
61
66
62
67
function cr_isapprox (d_ad, d_fd, rtol, atol)
@@ -75,21 +80,51 @@ function cr_isapprox(d_ad::Thunk, d_fd, rtol, atol)
75
80
return isapprox (extern (d_ad), d_fd; rtol= rtol, atol= atol)
76
81
end
77
82
78
- function test_adjoint! (x̄, dx, ȳ, partial)
79
- x̄_old = copy (x̄)
80
- x̄_zeros = zero .(x̄)
83
+ function test_accumulation (x̄, dx, ȳ, partial)
84
+ @test all (extern (ChainRules. add (x̄, partial)) .== extern (x̄) .+ extern (partial))
85
+ test_accumulate (x̄, dx, ȳ, partial)
86
+ test_accumulate! (x̄, dx, ȳ, partial)
87
+ test_store! (x̄, dx, ȳ, partial)
88
+ return nothing
89
+ end
90
+
91
+ function test_accumulate (x̄:: Zero , dx, ȳ, partial)
92
+ @test extern (accumulate (x̄, dx, ȳ)) == extern (partial)
93
+ return nothing
94
+ end
95
+
96
+ function test_accumulate (x̄:: Number , dx, ȳ, partial)
97
+ @test extern (accumulate (x̄, dx, ȳ)) == extern (x̄) + extern (partial)
98
+ return nothing
99
+ end
81
100
82
- @test all (accumulate (Zero (), dx, ȳ) .== accumulate (x̄_zeros, dx, ȳ))
83
- @test all (accumulate (x̄, dx, ȳ) .== (x̄ .+ partial))
101
+ function test_accumulate (x̄:: AbstractArray , dx, ȳ, partial)
102
+ x̄_old = copy (x̄)
103
+ @test all (extern (accumulate (x̄, dx, ȳ)) .== (extern (x̄) .+ extern (partial)))
84
104
@test x̄ == x̄_old
105
+ return nothing
106
+ end
85
107
86
- accumulate! (x̄, dx, ȳ)
87
- @test x̄ == (x̄_old .+ partial)
88
- x̄ .= x̄_old
108
+ test_accumulate! (x̄:: Zero , dx, ȳ, partial) = nothing
109
+
110
+ function test_accumulate! (x̄:: Number , dx, ȳ, partial)
111
+ @test accumulate! (x̄, dx, ȳ) == accumulate (x̄, dx, ȳ)
112
+ return nothing
113
+ end
114
+
115
+ function test_accumulate! (x̄:: AbstractArray , dx, ȳ, partial)
116
+ x̄_copy = copy (x̄)
117
+ accumulate! (x̄_copy, dx, ȳ)
118
+ @test extern (x̄_copy) == (extern (x̄) .+ extern (partial))
119
+ return nothing
120
+ end
89
121
90
- store! (x̄, dx, ȳ)
91
- @test all (x̄ .== partial)
92
- x̄ .= x̄_old
122
+ test_store! (x̄:: Zero , dx, ȳ, partial) = nothing
123
+ test_store! (x̄:: Number , dx, ȳ, partial) = nothing
93
124
125
+ function test_store! (x̄:: AbstractArray , dx, ȳ, partial)
126
+ x̄_copy = copy (x̄)
127
+ store! (x̄_copy, dx, ȳ)
128
+ @test all (x̄_copy .== extern (partial))
94
129
return nothing
95
130
end
0 commit comments