Skip to content

Commit 440591e

Browse files
author
Miha Zgubic
committed
autotangent mapreduce tests
1 parent 4acc061 commit 440591e

File tree

1 file changed

+5
-21
lines changed

1 file changed

+5
-21
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,23 @@
22
@testset "sum" begin
33
sizes = (3, 4, 7)
44
@testset "dims = $dims" for dims in (:, 1)
5-
fkwargs = (dims=dims,)
65
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
76
s = sizes[1:N]
87
x = randn(T, s...)
9-
= randn(T, s...)
10-
= randn(T, s...)
11-
y = sum(x; dims=dims)
12-
Δy = randn(eltype(y), size(y)...)
13-
frule_test(sum, (x, ẋ); fkwargs=fkwargs)
14-
rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs)
8+
test_frule(sum, x; fkwargs=(;dims=dims))
9+
test_rrule(sum, x; fkwargs=(;dims=dims))
1510
end
1611
end
1712
end # sum
1813

1914
@testset "sum abs2" begin
2015
sizes = (3, 4, 7)
2116
@testset "dims = $dims" for dims in (:, 1)
22-
fkwargs = (dims=dims,)
2317
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
2418
s = sizes[1:N]
25-
x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...)
26-
y = sum(abs2, x; dims=dims)
27-
Δy = randn(eltype(y), size(y)...)
28-
@testset "frule" begin
29-
# can't use frule_test here because it doesn't yet ignore nothing tangents
30-
y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims)
31-
@test y_ad == y
32-
ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ))
33-
@test ẏ_ad ẏ_fd
34-
end
35-
@testset "rrule" begin
36-
rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs)
37-
end
19+
x = randn(T, s...)
20+
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
21+
test_rrule(sum, abs2 nothing, x; fkwargs=(;dims=dims))
3822
end
3923
end
4024
end # sum abs2

0 commit comments

Comments
 (0)