Skip to content

Commit e049442

Browse files
authored
Merge pull request #364 from JuliaDiff/mz/arraymathtests
arraymath.jl autotangent
2 parents fb7b5c6 + e01c712 commit e049442

File tree

1 file changed

+28
-39
lines changed

1 file changed

+28
-39
lines changed

test/rulesets/Base/arraymath.jl

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,63 @@
11
@testset "arraymath" begin
22
@testset "inv(::Matrix{$T})" for T in (Float64, ComplexF64)
3-
N = 3
4-
B = generate_well_conditioned_matrix(T, N)
5-
frule_test(inv, (B, randn(T, N, N)))
6-
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
3+
B = generate_well_conditioned_matrix(T, 3)
4+
test_frule(inv, B)
5+
test_rrule(inv, B)
76
end
87

98
@testset "*: $T" for T in (Float64, ComplexF64)
109
(a) = round.(5*randn(T, a)) # Helper to generate nice random values
1110
(a, b) = ((a, b)) # matrix
1211
() = only((())) # scalar
1312

14-
(a) = ((a), (a)) # Helper to generate random matrix and its cotangent
15-
(a, b) = ((a, b)) #matrix
16-
() = (()) # scalar
17-
1813
@testset "Scalar-Array $dims" for dims in ((3,), (5,4), (2, 3, 4, 5))
19-
rrule_test(*, (dims), (), (dims))
20-
rrule_test(*, (dims), (dims), ())
14+
test_rrule(*, (), (dims))
15+
test_rrule(*, (dims), ())
2116
end
2217

2318
@testset "AbstractMatrix-AbstractVector n=$n, m=$m" for n in (2, 3), m in (4, 5)
2419
@testset "Array" begin
25-
rrule_test(*, (n), n m, (m))
20+
test_rrule(*, n m, (m))
2621
end
2722
end
2823

2924
@testset "AbstractVector-AbstractMatrix n=$n, m=$m" for n in (2, 3), m in (4, 5)
3025
@testset "Array" begin
31-
rrule_test(*, n m, (n), 1 m)
26+
test_rrule(*, (n), 1 m)
3227
end
3328
end
3429

3530
@testset "AbstractMatrix-AbstractMatrix" begin
3631
@testset "Matrix * Matrix n=$n, m=$m, p=$p" for n in (2, 5), m in (2, 4), p in (2, 3)
3732
@testset "Array" begin
38-
rrule_test(*, np, (nm), (mp))
33+
test_rrule(*, (nm), (mp))
3934
end
4035

4136
@testset "SubArray - $indexname" for (indexname, m_index) in (
42-
("fast", :), ("slow", Ref(m:-1:1))
37+
("fast", :), ("slow", m:-1:1)
4338
)
44-
rrule_test(*, np, view.(nm, :, m_index), view.(mp, m_index, :))
45-
rrule_test(*, np, nm, view.(mp, m_index, :))
46-
rrule_test(*, np, view.(nm, :, m_index), mp)
39+
test_rrule(*, view(nm, :, m_index), view(mp, m_index, :))
40+
test_rrule(*, nm, view(mp, m_index, :))
41+
test_rrule(*, view(nm, :, m_index), mp)
4742
end
4843

4944
@testset "Adjoints and Transposes" begin
50-
rrule_test(*, np, Transpose.(mn), Transpose.(pm))
51-
rrule_test(*, np, Adjoint.(mn), Adjoint.(pm))
45+
test_rrule(*, Transpose(mn) Transpose(mn), Transpose(pm) Transpose(pm))
46+
test_rrule(*, Adjoint(mn) Adjoint(mn), Adjoint(pm) Adjoint(pm))
5247

53-
rrule_test(*, np, Transpose.(mn), (mp))
54-
rrule_test(*, np, Adjoint.(mn), (mp))
48+
test_rrule(*, Transpose(mn) Transpose(mn), (mp))
49+
test_rrule(*, Adjoint(mn) Adjoint(mn), (mp))
5550

56-
rrule_test(*, np, (n₂m), Transpose.(pm))
57-
rrule_test(*, np, (n₂m), Adjoint.(pm))
51+
test_rrule(*, (nm), Transpose(pm) Transpose(pm))
52+
test_rrule(*, (nm), Adjoint(pm) Adjoint(pm))
5853
end
5954
end
6055
end
6156

6257
@testset "Covector * Vector n=$n" for n in (3, 5)
6358
@testset "$f" for f in (adjoint, transpose)
6459
# This should be same as dot product and give a scalar
65-
rrule_test(*, (), f.((n)), (n))
60+
test_rrule(*, f((n)) f((n)), (n))
6661
end
6762
end
6863
end
@@ -73,45 +68,39 @@
7368
for n in 3:5, m in 3:5
7469
A = randn(m, n)
7570
B = randn(m, n)
76-
= randn(size(f(A, B)))
77-
rrule_test(f, Ȳ, (A, randn(m, n)), (B, randn(m, n)))
71+
test_rrule(f, A, B)
7872
end
7973
end
8074
@testset "Vector" begin
8175
x = randn(10)
8276
y = randn(10)
83-
= randn(size(f(x, y))...)
84-
rrule_test(f, ȳ, (x, randn(10)), (y, randn(10)))
77+
test_rrule(f, x, y)
8578
end
8679
if f == (\)
8780
@testset "Matrix $f Vector" begin
8881
X = randn(10, 4)
8982
y = randn(10)
90-
= randn(size(f(X, y))...)
91-
rrule_test(f, ȳ, (X, randn(size(X))), (y, randn(10)))
83+
test_rrule(f, X, y)
9284
end
9385
@testset "Vector $f Matrix" begin
9486
x = randn(10)
9587
Y = randn(10, 4)
96-
= randn(size(f(x, Y))...)
97-
rrule_test(f, ȳ, (x, randn(size(x))), (Y, randn(size(Y))))
88+
test_rrule(f, x, Y; output_tangent=Transpose(rand(4)))
9889
end
9990
end
10091
end
10192
@testset "/ and \\ Scalar-AbstractArray" begin
10293
A = randn(3, 4, 5)
103-
= randn(3, 4, 5)
104-
= randn(3, 4, 5)
105-
rrule_test(/, Ȳ, (A, Ā), (7.2, 2.3))
106-
rrule_test(\, Ȳ, (7.2, 2.3), (A, Ā))
94+
test_rrule(/, A, 7.2)
95+
test_rrule(\, 7.2, A)
10796
end
10897

10998

11099
@testset "negation" begin
111100
A = randn(4, 4)
112101
= randn(4, 4)
113-
= randn(4, 4)
114-
rrule_test(-, Ȳ, (A, Ā))
115-
rrule_test(-, Diagonal(Ȳ), (Diagonal(A), Diagonal(Ā)))
102+
103+
test_rrule(-, A)
104+
test_rrule(-, Diagonal(A); output_tangent=Diagonal(Ā))
116105
end
117106
end

0 commit comments

Comments
 (0)