|
2 | 2 | @testset "sum" begin
|
3 | 3 | sizes = (3, 4, 7)
|
4 | 4 | @testset "dims = $dims" for dims in (:, 1)
|
| 5 | + fkwargs = (dims=dims,) |
5 | 6 | @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
|
6 | 7 | s = sizes[1:N]
|
7 | 8 | x = randn(T, s...)
|
8 |
| - test_frule(sum, x; fkwargs=(;dims=dims)) |
9 |
| - test_rrule(sum, x; fkwargs=(;dims=dims)) |
| 9 | + ẋ = randn(T, s...) |
| 10 | + x̄ = randn(T, s...) |
| 11 | + y = sum(x; dims=dims) |
| 12 | + Δy = randn(eltype(y), size(y)...) |
| 13 | + frule_test(sum, (x, ẋ); fkwargs=fkwargs) |
| 14 | + rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs) |
10 | 15 | end
|
11 | 16 | end
|
12 | 17 | end # sum
|
13 | 18 |
|
14 | 19 | @testset "sum abs2" begin
|
15 | 20 | sizes = (3, 4, 7)
|
16 | 21 | @testset "dims = $dims" for dims in (:, 1)
|
| 22 | + fkwargs = (dims=dims,) |
17 | 23 | @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
|
18 | 24 | s = sizes[1:N]
|
19 |
| - x = randn(T, s...) |
20 |
| - test_frule(sum, abs2, x; fkwargs=(;dims=dims)) |
21 |
| - test_rrule(sum, abs2 ⊢ nothing, x; fkwargs=(;dims=dims)) |
| 25 | + x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...) |
| 26 | + y = sum(abs2, x; dims=dims) |
| 27 | + Δy = randn(eltype(y), size(y)...) |
| 28 | + @testset "frule" begin |
| 29 | + # can't use frule_test here because it doesn't yet ignore nothing tangents |
| 30 | + y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims) |
| 31 | + @test y_ad == y |
| 32 | + ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ)) |
| 33 | + @test ẏ_ad ≈ ẏ_fd |
| 34 | + end |
| 35 | + @testset "rrule" begin |
| 36 | + rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs) |
| 37 | + end |
22 | 38 | end
|
23 | 39 | end
|
24 | 40 | end # sum abs2
|
|
0 commit comments