Skip to content

Commit 566c355

Browse files
authored
Merge pull request #361 from JuliaDiff/ox/array_autotangent
array.jl autotangent
2 parents e50a028 + ec42600 commit 566c355

File tree

1 file changed

+8
-48
lines changed

1 file changed

+8
-48
lines changed

test/rulesets/Base/array.jl

Lines changed: 8 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,37 @@
11
@testset "reshape" begin
2-
x = rand(4, 5)
3-
= rand(4, 5)
4-
5-
= rand(2, 10)
6-
7-
rrule_test(reshape, ȳ, (x, x̄), ((2, 10), nothing))
8-
rrule_test(reshape, ȳ, (x, x̄), (2, nothing), (10, nothing))
2+
test_rrule(reshape, rand(4, 5), (2, 10) nothing)
3+
test_rrule(reshape, rand(4, 5), 2 nothing, 10 nothing)
94
end
105

116
@testset "hcat" begin
127
A = randn(3, 2)
138
B = randn(3)
149
C = randn(3, 3)
15-
H, pullback = rrule(hcat, A, B, C)
16-
@test H == hcat(A, B, C)
17-
= randn(3, 6)
18-
(ds, dA, dB, dC) = pullback(H̄)
19-
@test ds == NO_FIELDS
20-
@test dA view(H̄, :, 1:2)
21-
@test dB view(H̄, :, 3)
22-
@test dC view(H̄, :, 4:6)
10+
test_rrule(hcat, A, B, C; check_inferred=false)
2311
end
2412

2513
@testset "reduce hcat" begin
2614
A = randn(3, 2)
2715
B = randn(3, 1)
2816
C = randn(3, 3)
29-
x = [A, B, C]
30-
H, pullback = rrule(reduce, hcat, x)
31-
@test H == reduce(hcat, x)
32-
= randn(3, 6)
33-
= randn.(size.(x))
34-
rrule_test(reduce, H̄, (hcat, nothing), (x, x̄))
17+
test_rrule(reduce, hcat nothing, [A, B, C])
3518
end
3619

3720
@testset "vcat" begin
3821
A = randn(2, 4)
3922
B = randn(1, 4)
4023
C = randn(3, 4)
41-
V, pullback = rrule(vcat, A, B, C)
42-
@test V == vcat(A, B, C)
43-
= randn(6, 4)
44-
(ds, dA, dB, dC) = pullback(V̄)
45-
@test ds == NO_FIELDS
46-
@test dA view(V̄, 1:2, :)
47-
@test dB view(V̄, 3:3, :)
48-
@test dC view(V̄, 4:6, :)
24+
test_rrule(vcat, A, B, C; check_inferred=false)
4925
end
5026

5127
@testset "reduce vcat" begin
5228
A = randn(2, 4)
5329
B = randn(1, 4)
5430
C = randn(3, 4)
55-
x = [A, B, C]
56-
V, pullback = rrule(reduce, vcat, x)
57-
@test V == reduce(vcat, x)
58-
= randn(6, 4)
59-
= randn.(size.(x))
60-
rrule_test(reduce, V̄, (vcat, nothing), (x, x̄))
31+
test_rrule(reduce, vcat nothing, [A, B, C])
6132
end
6233

6334
@testset "fill" begin
64-
y, pullback = rrule(fill, 44, 4)
65-
@test y == [44, 44, 44, 44]
66-
(ds, dv, dd) = pullback(ones(4))
67-
@test ds === NO_FIELDS
68-
@test dd isa DoesNotExist
69-
@test extern(dv) == 4
70-
71-
y, pullback = rrule(fill, 2.0, (3, 3, 3))
72-
@test y == fill(2.0, (3, 3, 3))
73-
(ds, dv, dd) = pullback(ones(3, 3, 3))
74-
@test ds === NO_FIELDS
75-
@test dd isa DoesNotExist
76-
@test dv 27.0
35+
test_rrule(fill, 44.0, 4 nothing; check_inferred=false)
36+
test_rrule(fill, 2.0, (3, 3, 3) nothing)
7737
end

0 commit comments

Comments
 (0)