|
2 | 2 | @testset "sum" begin
|
3 | 3 | sizes = (3, 4, 7)
|
4 | 4 | @testset "dims = $dims" for dims in (:, 1)
|
5 |
| - fkwargs = (dims=dims,) |
6 | 5 | @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
|
7 |
| - s = sizes[1:N] |
8 |
| - x = randn(T, s...) |
9 |
| - ẋ = randn(T, s...) |
10 |
| - x̄ = randn(T, s...) |
11 |
| - y = sum(x; dims=dims) |
12 |
| - Δy = randn(eltype(y), size(y)...) |
13 |
| - frule_test(sum, (x, ẋ); fkwargs=fkwargs) |
14 |
| - rrule_test(sum, Δy, (x, x̄); fkwargs=fkwargs) |
| 6 | + x = randn(T, sizes[1:N]...) |
| 7 | + test_frule(sum, x; fkwargs=(;dims=dims)) |
| 8 | + test_rrule(sum, x; fkwargs=(;dims=dims)) |
15 | 9 | end
|
16 | 10 | end
|
17 | 11 | end # sum
|
18 | 12 |
|
19 | 13 | @testset "sum abs2" begin
|
20 | 14 | sizes = (3, 4, 7)
|
21 | 15 | @testset "dims = $dims" for dims in (:, 1)
|
22 |
| - fkwargs = (dims=dims,) |
23 | 16 | @testset "Array{$N, $T}" for N in eachindex(sizes), T in (Float64, ComplexF64)
|
24 |
| - s = sizes[1:N] |
25 |
| - x, ẋ, x̄ = randn(T, s...), randn(T, s...), randn(T, s...) |
26 |
| - y = sum(abs2, x; dims=dims) |
27 |
| - Δy = randn(eltype(y), size(y)...) |
28 |
| - @testset "frule" begin |
29 |
| - # can't use frule_test here because it doesn't yet ignore nothing tangents |
30 |
| - y_ad, ẏ_ad = frule((Zero(), Zero(), ẋ), sum, abs2, x; dims=dims) |
31 |
| - @test y_ad == y |
32 |
| - ẏ_fd = jvp(_fdm, z -> sum(abs2, z; dims=dims), (x, ẋ)) |
33 |
| - @test ẏ_ad ≈ ẏ_fd |
34 |
| - end |
35 |
| - @testset "rrule" begin |
36 |
| - rrule_test(sum, Δy, (abs2, nothing), (x, x̄); fkwargs=fkwargs) |
37 |
| - end |
| 17 | + x = randn(T, sizes[1:N]...) |
| 18 | + test_frule(sum, abs2, x; fkwargs=(;dims=dims)) |
| 19 | + test_rrule(sum, abs2 ⊢ nothing, x; fkwargs=(;dims=dims)) |
38 | 20 | end
|
39 | 21 | end
|
40 | 22 | end # sum abs2
|
|
0 commit comments