Skip to content

Commit fb7b5c6

Browse files
authored
Merge pull request #363 from JuliaDiff/mz/mapreducetests
autotangent mapreduce.jl tests
2 parents 4acc061 + 1b28f0c commit fb7b5c6

File tree

1 file changed

+6
-24
lines changed

1 file changed

+6
-24
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,21 @@
22
@testset "sum" begin
33
sizes = (3, 4, 7)
44
@testset "dims = $dims" for dims in (:, 1)
5-
fkwargs = (dims=dims,)
65
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
7-
s = sizes[1:N]
8-
x = randn(T, s...)
9-
= randn(T, s...)
10-
= randn(T, s...)
11-
y = sum(x; dims=dims)
12-
Δy = randn(eltype(y), size(y)...)
13-
frule_test(sum, (x, ẋ); fkwargs=fkwargs)
14-
rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs)
6+
x = randn(T, sizes[1:N]...)
7+
test_frule(sum, x; fkwargs=(;dims=dims))
8+
test_rrule(sum, x; fkwargs=(;dims=dims))
159
end
1610
end
1711
end # sum
1812

1913
@testset "sum abs2" begin
2014
sizes = (3, 4, 7)
2115
@testset "dims = $dims" for dims in (:, 1)
22-
fkwargs = (dims=dims,)
2316
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
24-
s = sizes[1:N]
25-
x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...)
26-
y = sum(abs2, x; dims=dims)
27-
Δy = randn(eltype(y), size(y)...)
28-
@testset "frule" begin
29-
# can't use frule_test here because it doesn't yet ignore nothing tangents
30-
y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims)
31-
@test y_ad == y
32-
ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ))
33-
@test ẏ_ad ẏ_fd
34-
end
35-
@testset "rrule" begin
36-
rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs)
37-
end
17+
x = randn(T, sizes[1:N]...)
18+
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
19+
test_rrule(sum, abs2 nothing, x; fkwargs=(;dims=dims))
3820
end
3921
end
4022
end # sum abs2

0 commit comments

Comments
 (0)