Skip to content

Commit 4a77b5b

Browse files
willtebbuttararslan
authored andcommitted
Tidy up tests a bit (#17)
1 parent e06cfd1 commit 4a77b5b

File tree

2 files changed

+50
-15
lines changed

2 files changed

+50
-15
lines changed

test/rules/base.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ end
7676
@test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄))
7777
@test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄))
7878

79-
test_adjoint!(rand(3, 2), dx, z̄, z̄ * y')
80-
test_adjoint!(rand(2, 5), dy, z̄, x' * z̄)
79+
test_accumulation(rand(3, 2), dx, z̄, z̄ * y')
80+
test_accumulation(rand(2, 5), dy, z̄, x' * z̄)
8181
end
8282
@testset "hypot(x, y)" begin
8383
x, y = rand(2)

test/test_util.jl

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
4242
@test cr_isapprox(x̄_ad, x̄_fd, rtol, atol)
4343

4444
# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
45-
test_adjoint!(x̄, dx, ȳ, x̄_ad)
45+
test_accumulation(x̄, dx, ȳ, x̄_ad)
46+
test_accumulation(Zero(), dx, ȳ, x̄_ad)
4647
end
4748

4849
function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
@@ -56,7 +57,11 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
5657
@test all(map((Δx_ad, Δx_fd)->cr_isapprox(Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))
5758

5859
# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
59-
map((x̄, Δx_rule, Δx_ad)->test_adjoint!(x̄, Δx_rule, ȳ, Δx_ad), x̄s, Δx_rules, Δxs_ad)
60+
map(x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
61+
test_accumulation(x̄, Δx_rule, ȳ, Δx_ad)
62+
test_accumulation(Zero(), Δx_rule, ȳ, Δx_ad)
63+
return nothing
64+
end
6065
end
6166

6267
function cr_isapprox(d_ad, d_fd, rtol, atol)
@@ -75,21 +80,51 @@ function cr_isapprox(d_ad::Thunk, d_fd, rtol, atol)
7580
return isapprox(extern(d_ad), d_fd; rtol=rtol, atol=atol)
7681
end
7782

78-
function test_adjoint!(x̄, dx, ȳ, partial)
79-
x̄_old = copy(x̄)
80-
x̄_zeros = zero.(x̄)
83+
function test_accumulation(x̄, dx, ȳ, partial)
84+
@test all(extern(ChainRules.add(x̄, partial)) .== extern(x̄) .+ extern(partial))
85+
test_accumulate(x̄, dx, ȳ, partial)
86+
test_accumulate!(x̄, dx, ȳ, partial)
87+
test_store!(x̄, dx, ȳ, partial)
88+
return nothing
89+
end
90+
91+
function test_accumulate(x̄::Zero, dx, ȳ, partial)
92+
@test extern(accumulate(x̄, dx, ȳ)) == extern(partial)
93+
return nothing
94+
end
95+
96+
function test_accumulate(x̄::Number, dx, ȳ, partial)
97+
@test extern(accumulate(x̄, dx, ȳ)) == extern(x̄) + extern(partial)
98+
return nothing
99+
end
81100

82-
@test all(accumulate(Zero(), dx, ȳ) .== accumulate(x̄_zeros, dx, ȳ))
83-
@test all(accumulate(x̄, dx, ȳ) .== (x̄ .+ partial))
101+
function test_accumulate(x̄::AbstractArray, dx, ȳ, partial)
102+
x̄_old = copy(x̄)
103+
@test all(extern(accumulate(x̄, dx, ȳ)) .== (extern(x̄) .+ extern(partial)))
84104
@test== x̄_old
105+
return nothing
106+
end
85107

86-
accumulate!(x̄, dx, ȳ)
87-
@test== (x̄_old .+ partial)
88-
x̄ .= x̄_old
108+
test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing
109+
110+
function test_accumulate!(x̄::Number, dx, ȳ, partial)
111+
@test accumulate!(x̄, dx, ȳ) == accumulate(x̄, dx, ȳ)
112+
return nothing
113+
end
114+
115+
function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial)
116+
x̄_copy = copy(x̄)
117+
accumulate!(x̄_copy, dx, ȳ)
118+
@test extern(x̄_copy) == (extern(x̄) .+ extern(partial))
119+
return nothing
120+
end
89121

90-
store!(x̄, dx, ȳ)
91-
@test all(x̄ .== partial)
92-
x̄ .= x̄_old
122+
test_store!(x̄::Zero, dx, ȳ, partial) = nothing
123+
test_store!(x̄::Number, dx, ȳ, partial) = nothing
93124

125+
function test_store!(x̄::AbstractArray, dx, ȳ, partial)
126+
x̄_copy = copy(x̄)
127+
store!(x̄_copy, dx, ȳ)
128+
@test all(x̄_copy .== extern(partial))
94129
return nothing
95130
end

0 commit comments

Comments
 (0)