|
1 | 1 | @testset "reshape" begin
|
2 |
| - x = rand(4, 5) |
3 |
| - x̄ = rand(4, 5) |
4 |
| - |
5 |
| - ȳ = rand(2, 10) |
6 |
| - |
7 |
| - rrule_test(reshape, ȳ, (x, x̄), ((2, 10), nothing)) |
8 |
| - rrule_test(reshape, ȳ, (x, x̄), (2, nothing), (10, nothing)) |
| 2 | + test_rrule(reshape, rand(4, 5), (2, 10) ⊢ nothing) |
| 3 | + test_rrule(reshape, rand(4, 5), 2 ⊢ nothing, 10 ⊢ nothing) |
9 | 4 | end
|
10 | 5 |
|
11 | 6 | @testset "hcat" begin
|
12 | 7 | A = randn(3, 2)
|
13 | 8 | B = randn(3)
|
14 | 9 | C = randn(3, 3)
|
15 |
| - H, pullback = rrule(hcat, A, B, C) |
16 |
| - @test H == hcat(A, B, C) |
17 |
| - H̄ = randn(3, 6) |
18 |
| - (ds, dA, dB, dC) = pullback(H̄) |
19 |
| - @test ds == NO_FIELDS |
20 |
| - @test dA ≈ view(H̄, :, 1:2) |
21 |
| - @test dB ≈ view(H̄, :, 3) |
22 |
| - @test dC ≈ view(H̄, :, 4:6) |
| 10 | + test_rrule(hcat, A, B, C; check_inferred=false) |
23 | 11 | end
|
24 | 12 |
|
25 | 13 | @testset "reduce hcat" begin
|
26 | 14 | A = randn(3, 2)
|
27 | 15 | B = randn(3, 1)
|
28 | 16 | C = randn(3, 3)
|
29 |
| - x = [A, B, C] |
30 |
| - H, pullback = rrule(reduce, hcat, x) |
31 |
| - @test H == reduce(hcat, x) |
32 |
| - H̄ = randn(3, 6) |
33 |
| - x̄ = randn.(size.(x)) |
34 |
| - rrule_test(reduce, H̄, (hcat, nothing), (x, x̄)) |
| 17 | + test_rrule(reduce, hcat ⊢ nothing, [A, B, C]) |
35 | 18 | end
|
36 | 19 |
|
37 | 20 | @testset "vcat" begin
|
38 | 21 | A = randn(2, 4)
|
39 | 22 | B = randn(1, 4)
|
40 | 23 | C = randn(3, 4)
|
41 |
| - V, pullback = rrule(vcat, A, B, C) |
42 |
| - @test V == vcat(A, B, C) |
43 |
| - V̄ = randn(6, 4) |
44 |
| - (ds, dA, dB, dC) = pullback(V̄) |
45 |
| - @test ds == NO_FIELDS |
46 |
| - @test dA ≈ view(V̄, 1:2, :) |
47 |
| - @test dB ≈ view(V̄, 3:3, :) |
48 |
| - @test dC ≈ view(V̄, 4:6, :) |
| 24 | + test_rrule(vcat, A, B, C; check_inferred=false) |
49 | 25 | end
|
50 | 26 |
|
51 | 27 | @testset "reduce vcat" begin
|
52 | 28 | A = randn(2, 4)
|
53 | 29 | B = randn(1, 4)
|
54 | 30 | C = randn(3, 4)
|
55 |
| - x = [A, B, C] |
56 |
| - V, pullback = rrule(reduce, vcat, x) |
57 |
| - @test V == reduce(vcat, x) |
58 |
| - V̄ = randn(6, 4) |
59 |
| - x̄ = randn.(size.(x)) |
60 |
| - rrule_test(reduce, V̄, (vcat, nothing), (x, x̄)) |
| 31 | + test_rrule(reduce, vcat ⊢ nothing, [A, B, C]) |
61 | 32 | end
|
62 | 33 |
|
63 | 34 | @testset "fill" begin
|
64 |
| - y, pullback = rrule(fill, 44, 4) |
65 |
| - @test y == [44, 44, 44, 44] |
66 |
| - (ds, dv, dd) = pullback(ones(4)) |
67 |
| - @test ds === NO_FIELDS |
68 |
| - @test dd isa DoesNotExist |
69 |
| - @test extern(dv) == 4 |
70 |
| - |
71 |
| - y, pullback = rrule(fill, 2.0, (3, 3, 3)) |
72 |
| - @test y == fill(2.0, (3, 3, 3)) |
73 |
| - (ds, dv, dd) = pullback(ones(3, 3, 3)) |
74 |
| - @test ds === NO_FIELDS |
75 |
| - @test dd isa DoesNotExist |
76 |
| - @test dv ≈ 27.0 |
| 35 | + test_rrule(fill, 44.0, 4 ⊢ nothing; check_inferred=false) |
| 36 | + test_rrule(fill, 2.0, (3, 3, 3) ⊢ nothing) |
77 | 37 | end
|
0 commit comments