Skip to content

Commit f0fead0

Browse files
authored
Merge pull request #369 from JuliaDiff/ox/tandense
dense.jl autotangent
2 parents 5ed33f1 + d99bf59 commit f0fead0

File tree

1 file changed

+47
-75
lines changed

1 file changed

+47
-75
lines changed

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 47 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,34 @@
11
@testset "dense" begin
22
@testset "dot" begin
33
@testset "Vector{$T}" for T in (Float64, ComplexF64)
4-
M = 3
5-
x, y = randn(T, M), randn(T, M)
6-
ẋ, ẏ = randn(T, M), randn(T, M)
7-
x̄, ȳ = randn(T, M), randn(T, M)
8-
frule_test(dot, (x, ẋ), (y, ẏ))
9-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
4+
test_frule(dot, randn(T, 3), randn(T, 3))
5+
test_rrule(dot, randn(T, 3), randn(T, 3))
106
end
117
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
12-
M, N = 3, 4
13-
x, y = randn(T, M, N), randn(T, M, N)
14-
ẋ, ẏ = randn(T, M, N), randn(T, M, N)
15-
x̄, ȳ = randn(T, M, N), randn(T, M, N)
16-
frule_test(dot, (x, ẋ), (y, ẏ))
17-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
8+
test_frule(dot, randn(T, 3, 4), randn(T, 3, 4))
9+
test_rrule(dot, randn(T, 3, 4), randn(T, 3, 4))
1810
end
1911
@testset "Array{$T, 3}" for T in (Float64, ComplexF64)
20-
M, N, P = 3, 4, 5
21-
x, y = randn(T, M, N, P), randn(T, M, N, P)
22-
ẋ, ẏ = randn(T, M, N, P), randn(T, M, N, P)
23-
x̄, ȳ = randn(T, M, N, P), randn(T, M, N, P)
24-
frule_test(dot, (x, ẋ), (y, ẏ))
25-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
12+
test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
13+
test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
2614
end
2715
@testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
28-
M, N = 3, 4
29-
x, A, y = randn(T, M), randn(T, M, N), randn(T, N)
30-
ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N)
31-
x̄, Abar, ȳ = randn(T, M), randn(T, M, N), randn(T, N)
32-
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ))
33-
rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ))
16+
test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
17+
test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
3418
end
3519
permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
3620
@testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
37-
M, N = 3, 4
38-
x, A, y = rand(T, M), F(rand(T, N, M)), rand(T, N)
39-
ẋ, Adot, ẏ = rand(T, M), F(rand(T, N, M)), rand(T, N)
40-
x̄, Abar, ȳ = rand(T, M), F(rand(T, N, M)), rand(T, N)
41-
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ); rtol=1f-3)
42-
rrule_test(dot, float(rand(T)), (x, x̄), (A, Abar), (y, ȳ); rtol=1f-3)
21+
A = F(rand(T, 4, 3)) F(rand(T, 4, 3))
22+
test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
23+
test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
4324
end
4425
end
26+
4527
@testset "cross" begin
46-
@testset "frule" begin
47-
@testset "$T" for T in (Float64, ComplexF64)
48-
n = 3
49-
x, y = randn(T, n), randn(T, n)
50-
ẋ, ẏ = randn(T, n), randn(T, n)
51-
frule_test(cross, (x, ẋ), (y, ẏ))
52-
end
53-
end
54-
@testset "rrule" begin
55-
n = 3
56-
x, y = randn(n), randn(n)
57-
x̄, ȳ = randn(n), randn(n)
58-
ΔΩ = randn(n)
59-
rrule_test(cross, ΔΩ, (x, x̄), (y, ȳ))
60-
end
28+
test_frule(cross, randn(3), randn(3))
29+
test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))
30+
test_rrule(cross, randn(3), randn(3))
31+
# No complex support for rrule(cross,...
6132
end
6233
@testset "pinv" begin
6334
@testset "$T" for T in (Float64, ComplexF64)
@@ -66,25 +37,31 @@
6637
@test rrule(pinv, zero(T))[2](randn(T))[2] zero(T)
6738
end
6839
@testset "Vector{$T}" for T in (Float64, ComplexF64)
69-
n = 3
70-
x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
71-
tol, ṫol, t̄ol = 0.0, randn(), randn()
72-
Δy = copyto!(similar(pinv(x)), randn(T, n))
73-
frule_test(pinv, (x, ẋ), (tol, ṫol))
40+
test_frule(pinv, randn(T, 3), 0.0)
41+
test_frule(pinv, randn(T, 3), 0.0)
42+
43+
# Checking types. TODO do we still need this?
44+
x = randn(T, 3)
45+
= randn(T, 3)
46+
Δy = copyto!(similar(pinv(x)), randn(T, 3))
7447
@test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
75-
rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol))
7648
@test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
7749
end
50+
7851
@testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
79-
n = 3
80-
x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n))
52+
test_frule(pinv, F(randn(T, 3)) F(randn(T, 3)))
53+
test_rrule(pinv, F(randn(T, 3)))
54+
55+
# Check types.
56+
# TODO: Do we need this still?
57+
x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3))
8158
y = pinv(x)
82-
Δy = copyto!(similar(y), randn(T, n))
83-
frule_test(pinv, (x, ẋ))
59+
Δy = copyto!(similar(y), randn(T, 3))
60+
8461
y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
8562
@test y_fwd isa typeof(y)
8663
@test ∂y_fwd isa typeof(y)
87-
rrule_test(pinv, Δy, (x, x̄))
64+
8865
y_rev, back = rrule(pinv, x)
8966
@test y_rev isa typeof(y)
9067
@test back(Δy)[2] isa typeof(x)
@@ -93,10 +70,8 @@
9370
m in 1:3,
9471
n in 1:3
9572

96-
X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n)
97-
ΔY = randn(T, size(pinv(X))...)
98-
frule_test(pinv, (X, Ẋ))
99-
rrule_test(pinv, ΔY, (X, X̄))
73+
test_frule(pinv, randn(T, m, n))
74+
test_rrule(pinv, randn(T, m, n))
10075
end
10176
end
10277
@testset "$f" for f in (det, logdet)
@@ -105,29 +80,26 @@
10580
test_scalar(f, b)
10681
end
10782
@testset "$f(::Matrix{$T})" for T in (Float64, ComplexF64)
83+
B = generate_well_conditioned_matrix(T, 4)
10884
if f === logdet && float(T) <: Float32
109-
kwargs = (atol=1e-5, rtol=1e-5)
85+
test_frule(f, B; atol=1e-5, rtol=1e-5)
86+
test_rrule(f, B; atol=1e-5, rtol=1e-5)
11087
else
111-
kwargs = NamedTuple()
88+
test_frule(f, B)
89+
test_rrule(f, B)
11290
end
113-
N = 3
114-
B = generate_well_conditioned_matrix(T, N)
115-
frule_test(f, (B, randn(T, N, N)); kwargs...)
116-
rrule_test(f, randn(T), (B, randn(T, N, N)); kwargs...)
11791
end
11892
end
11993
@testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
120-
N = 3
121-
B = randn(T, N, N)
122-
frule_test(logabsdet, (B, randn(T, N, N)))
123-
rrule_test(logabsdet, (randn(), randn(T)), (B, randn(T, N, N)))
94+
B = randn(T, 4, 4)
95+
test_frule(logabsdet, B)
96+
test_rrule(logabsdet, B)
12497
# test for opposite sign of determinant
125-
frule_test(logabsdet, (-B, randn(T, N, N)))
126-
rrule_test(logabsdet, (randn(), randn(T)), (-B, randn(T, N, N)))
98+
test_frule(logabsdet, -B)
99+
test_rrule(logabsdet, -B)
127100
end
128101
@testset "tr" begin
129-
N = 4
130-
frule_test(tr, (randn(N, N), randn(N, N)))
131-
rrule_test(tr, randn(), (randn(N, N), randn(N, N)))
102+
test_frule(tr, randn(4, 4))
103+
test_rrule(tr, randn(4, 4))
132104
end
133105
end

0 commit comments

Comments
 (0)