Skip to content

Commit cb0864f

Browse files
author
Miha Zgubic
committed
Revert "autotangent mapreduce tests"
This reverts commit 440591e.
1 parent 15f644c commit cb0864f

File tree

1 file changed

+21
-5
lines changed

1 file changed

+21
-5
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,39 @@
22
@testset "sum" begin
33
sizes = (3, 4, 7)
44
@testset "dims = $dims" for dims in (:, 1)
5+
fkwargs = (dims=dims,)
56
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
67
s = sizes[1:N]
78
x = randn(T, s...)
8-
test_frule(sum, x; fkwargs=(;dims=dims))
9-
test_rrule(sum, x; fkwargs=(;dims=dims))
9+
= randn(T, s...)
10+
= randn(T, s...)
11+
y = sum(x; dims=dims)
12+
Δy = randn(eltype(y), size(y)...)
13+
frule_test(sum, (x, ẋ); fkwargs=fkwargs)
14+
rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs)
1015
end
1116
end
1217
end # sum
1318

1419
@testset "sum abs2" begin
1520
sizes = (3, 4, 7)
1621
@testset "dims = $dims" for dims in (:, 1)
22+
fkwargs = (dims=dims,)
1723
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
1824
s = sizes[1:N]
19-
x = randn(T, s...)
20-
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
21-
test_rrule(sum, abs2 nothing, x; fkwargs=(;dims=dims))
25+
x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...)
26+
y = sum(abs2, x; dims=dims)
27+
Δy = randn(eltype(y), size(y)...)
28+
@testset "frule" begin
29+
# can't use frule_test here because it doesn't yet ignore nothing tangents
30+
y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims)
31+
@test y_ad == y
32+
ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ))
33+
@test ẏ_ad ẏ_fd
34+
end
35+
@testset "rrule" begin
36+
rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs)
37+
end
2238
end
2339
end
2440
end # sum abs2

0 commit comments

Comments
 (0)