|
7 | 7 | A = randn(3, 2)
|
8 | 8 | B = randn(3)
|
9 | 9 | C = randn(3, 3)
|
10 |
| - H, pullback = rrule(hcat, A, B, C) |
11 |
| - @test H == hcat(A, B, C) |
12 |
| - H̄ = randn(3, 6) |
13 |
| - (ds, dA, dB, dC) = pullback(H̄) |
14 |
| - @test ds == NO_FIELDS |
15 |
| - @test dA ≈ view(H̄, :, 1:2) |
16 |
| - @test dB ≈ view(H̄, :, 3) |
17 |
| - @test dC ≈ view(H̄, :, 4:6) |
| 10 | + test_rrule(hcat, A, B, C; check_inferred=false) |
18 | 11 | end
|
19 | 12 |
|
20 | 13 | @testset "reduce hcat" begin
|
21 | 14 | A = randn(3, 2)
|
22 | 15 | B = randn(3, 1)
|
23 | 16 | C = randn(3, 3)
|
24 |
| - x = [A, B, C] |
25 |
| - H, pullback = rrule(reduce, hcat, x) |
26 |
| - @test H == reduce(hcat, x) |
27 |
| - H̄ = randn(3, 6) |
28 |
| - x̄ = randn.(size.(x)) |
29 |
| - rrule_test(reduce, H̄, (hcat, nothing), (x, x̄)) |
| 17 | + test_rrule(reduce, hcat ⊢ nothing, [A, B, C]) |
30 | 18 | end
|
31 | 19 |
|
32 | 20 | @testset "vcat" begin
|
33 | 21 | A = randn(2, 4)
|
34 | 22 | B = randn(1, 4)
|
35 | 23 | C = randn(3, 4)
|
36 |
| - V, pullback = rrule(vcat, A, B, C) |
37 |
| - @test V == vcat(A, B, C) |
38 |
| - V̄ = randn(6, 4) |
39 |
| - (ds, dA, dB, dC) = pullback(V̄) |
40 |
| - @test ds == NO_FIELDS |
41 |
| - @test dA ≈ view(V̄, 1:2, :) |
42 |
| - @test dB ≈ view(V̄, 3:3, :) |
43 |
| - @test dC ≈ view(V̄, 4:6, :) |
| 24 | + test_rrule(vcat, A, B, C; check_inferred=false) |
44 | 25 | end
|
45 | 26 |
|
46 | 27 | @testset "reduce vcat" begin
|
47 | 28 | A = randn(2, 4)
|
48 | 29 | B = randn(1, 4)
|
49 | 30 | C = randn(3, 4)
|
50 |
| - x = [A, B, C] |
51 |
| - V, pullback = rrule(reduce, vcat, x) |
52 |
| - @test V == reduce(vcat, x) |
53 |
| - V̄ = randn(6, 4) |
54 |
| - x̄ = randn.(size.(x)) |
55 |
| - rrule_test(reduce, V̄, (vcat, nothing), (x, x̄)) |
| 31 | + test_rrule(reduce, vcat ⊢ nothing, [A, B, C]) |
56 | 32 | end
|
57 | 33 |
|
58 | 34 | @testset "fill" begin
|
59 |
| - y, pullback = rrule(fill, 44, 4) |
60 |
| - @test y == [44, 44, 44, 44] |
61 |
| - (ds, dv, dd) = pullback(ones(4)) |
62 |
| - @test ds === NO_FIELDS |
63 |
| - @test dd isa DoesNotExist |
64 |
| - @test extern(dv) == 4 |
65 |
| - |
66 |
| - y, pullback = rrule(fill, 2.0, (3, 3, 3)) |
67 |
| - @test y == fill(2.0, (3, 3, 3)) |
68 |
| - (ds, dv, dd) = pullback(ones(3, 3, 3)) |
69 |
| - @test ds === NO_FIELDS |
70 |
| - @test dd isa DoesNotExist |
71 |
| - @test dv ≈ 27.0 |
| 35 | + test_rrule(fill, 44.0, 4 ⊢ nothing; check_inferred=false) |
| 36 | + test_rrule(fill, 2.0, (3, 3, 3) ⊢ nothing) |
72 | 37 | end
|
0 commit comments