Skip to content

Commit ec42600

Browse files
committed
use automatic tangents
1 parent 66c8880 commit ec42600

File tree

1 file changed

+6
-41
lines changed

1 file changed

+6
-41
lines changed

test/rulesets/Base/array.jl

Lines changed: 6 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,31 @@ end
77
A = randn(3, 2)
88
B = randn(3)
99
C = randn(3, 3)
10-
H, pullback = rrule(hcat, A, B, C)
11-
@test H == hcat(A, B, C)
12-
= randn(3, 6)
13-
(ds, dA, dB, dC) = pullback(H̄)
14-
@test ds == NO_FIELDS
15-
@test dA view(H̄, :, 1:2)
16-
@test dB view(H̄, :, 3)
17-
@test dC view(H̄, :, 4:6)
10+
test_rrule(hcat, A, B, C; check_inferred=false)
1811
end
1912

2013
@testset "reduce hcat" begin
2114
A = randn(3, 2)
2215
B = randn(3, 1)
2316
C = randn(3, 3)
24-
x = [A, B, C]
25-
H, pullback = rrule(reduce, hcat, x)
26-
@test H == reduce(hcat, x)
27-
= randn(3, 6)
28-
= randn.(size.(x))
29-
rrule_test(reduce, H̄, (hcat, nothing), (x, x̄))
17+
test_rrule(reduce, hcat nothing, [A, B, C])
3018
end
3119

3220
@testset "vcat" begin
3321
A = randn(2, 4)
3422
B = randn(1, 4)
3523
C = randn(3, 4)
36-
V, pullback = rrule(vcat, A, B, C)
37-
@test V == vcat(A, B, C)
38-
= randn(6, 4)
39-
(ds, dA, dB, dC) = pullback(V̄)
40-
@test ds == NO_FIELDS
41-
@test dA view(V̄, 1:2, :)
42-
@test dB view(V̄, 3:3, :)
43-
@test dC view(V̄, 4:6, :)
24+
test_rrule(vcat, A, B, C; check_inferred=false)
4425
end
4526

4627
@testset "reduce vcat" begin
4728
A = randn(2, 4)
4829
B = randn(1, 4)
4930
C = randn(3, 4)
50-
x = [A, B, C]
51-
V, pullback = rrule(reduce, vcat, x)
52-
@test V == reduce(vcat, x)
53-
= randn(6, 4)
54-
= randn.(size.(x))
55-
rrule_test(reduce, V̄, (vcat, nothing), (x, x̄))
31+
test_rrule(reduce, vcat nothing, [A, B, C])
5632
end
5733

5834
@testset "fill" begin
59-
y, pullback = rrule(fill, 44, 4)
60-
@test y == [44, 44, 44, 44]
61-
(ds, dv, dd) = pullback(ones(4))
62-
@test ds === NO_FIELDS
63-
@test dd isa DoesNotExist
64-
@test extern(dv) == 4
65-
66-
y, pullback = rrule(fill, 2.0, (3, 3, 3))
67-
@test y == fill(2.0, (3, 3, 3))
68-
(ds, dv, dd) = pullback(ones(3, 3, 3))
69-
@test ds === NO_FIELDS
70-
@test dd isa DoesNotExist
71-
@test dv 27.0
35+
test_rrule(fill, 44.0, 4 nothing; check_inferred=false)
36+
test_rrule(fill, 2.0, (3, 3, 3) nothing)
7237
end

0 commit comments

Comments
 (0)