Skip to content

Commit 27ed10c

Browse files
author
Miha Zgubic
committed
resolve mapreduce
2 parents cb0864f + fb7b5c6 commit 27ed10c

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

test/rulesets/Base/mapreduce.jl

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,9 @@
44
@testset "dims = $dims" for dims in (:, 1)
55
fkwargs = (dims=dims,)
66
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
7-
s = sizes[1:N]
8-
x = randn(T, s...)
9-
= randn(T, s...)
10-
= randn(T, s...)
11-
y = sum(x; dims=dims)
12-
Δy = randn(eltype(y), size(y)...)
13-
frule_test(sum, (x, ẋ); fkwargs=fkwargs)
14-
rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs)
7+
x = randn(T, sizes[1:N]...)
8+
test_frule(sum, x; fkwargs=(;dims=dims))
9+
test_rrule(sum, x; fkwargs=(;dims=dims))
1510
end
1611
end
1712
end # sum
@@ -21,20 +16,9 @@
2116
@testset "dims = $dims" for dims in (:, 1)
2217
fkwargs = (dims=dims,)
2318
@testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
24-
s = sizes[1:N]
25-
x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...)
26-
y = sum(abs2, x; dims=dims)
27-
Δy = randn(eltype(y), size(y)...)
28-
@testset "frule" begin
29-
# can't use frule_test here because it doesn't yet ignore nothing tangents
30-
y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims)
31-
@test y_ad == y
32-
ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ))
33-
@test ẏ_ad ẏ_fd
34-
end
35-
@testset "rrule" begin
36-
rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs)
37-
end
19+
x = randn(T, sizes[1:N]...)
20+
test_frule(sum, abs2, x; fkwargs=(;dims=dims))
21+
test_rrule(sum, abs2 nothing, x; fkwargs=(;dims=dims))
3822
end
3923
end
4024
end # sum abs2

0 commit comments

Comments
 (0)