Skip to content

Commit aa8f158

Browse files
committed
WIP update dense.jl to use automatic tangents
1 parent efd4c0e commit aa8f158

File tree

1 file changed

+45
-75
lines changed

1 file changed

+45
-75
lines changed

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 45 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,34 @@
11
@testset "dense" begin
22
@testset "dot" begin
33
@testset "Vector{$T}" for T in (Float64, ComplexF64)
4-
M = 3
5-
x, y = randn(T, M), randn(T, M)
6-
ẋ, ẏ = randn(T, M), randn(T, M)
7-
x̄, ȳ = randn(T, M), randn(T, M)
8-
frule_test(dot, (x, ẋ), (y, ẏ))
9-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
4+
test_frule(dot, randn(T, 3), randn(T, 3))
5+
test_rrule(dot, randn(T, 3), randn(T, 3))
106
end
117
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
12-
M, N = 3, 4
13-
x, y = randn(T, M, N), randn(T, M, N)
14-
ẋ, ẏ = randn(T, M, N), randn(T, M, N)
15-
x̄, ȳ = randn(T, M, N), randn(T, M, N)
16-
frule_test(dot, (x, ẋ), (y, ẏ))
17-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
8+
test_frule(dot, randn(T, 3, 4), randn(T, 3, 4))
9+
test_rrule(dot, randn(T, 3, 4), randn(T, 3, 4))
1810
end
1911
@testset "Array{$T, 3}" for T in (Float64, ComplexF64)
20-
M, N, P = 3, 4, 5
21-
x, y = randn(T, M, N, P), randn(T, M, N, P)
22-
ẋ, ẏ = randn(T, M, N, P), randn(T, M, N, P)
23-
x̄, ȳ = randn(T, M, N, P), randn(T, M, N, P)
24-
frule_test(dot, (x, ẋ), (y, ẏ))
25-
rrule_test(dot, randn(T), (x, x̄), (y, ȳ))
12+
test_frule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
13+
test_rrule(dot, randn(T, 3, 4, 5), randn(T, 3, 4, 5))
2614
end
2715
@testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64)
28-
M, N = 3, 4
29-
x, A, y = randn(T, M), randn(T, M, N), randn(T, N)
30-
ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N)
31-
x̄, Abar, ȳ = randn(T, M), randn(T, M, N), randn(T, N)
32-
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ))
33-
rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ))
16+
test_frule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
17+
test_rrule(dot, randn(T, 3), randn(T, 3, 4), randn(T, 4))
3418
end
3519
permuteddimsarray(A) = PermutedDimsArray(A, (2,1))
3620
@testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray)
37-
M, N = 3, 4
38-
x, A, y = rand(T, M), F(rand(T, N, M)), rand(T, N)
39-
ẋ, Adot, ẏ = rand(T, M), F(rand(T, N, M)), rand(T, N)
40-
x̄, Abar, ȳ = rand(T, M), F(rand(T, N, M)), rand(T, N)
41-
frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ); rtol=1f-3)
42-
rrule_test(dot, float(rand(T)), (x, x̄), (A, Abar), (y, ȳ); rtol=1f-3)
21+
A = F(rand(T, 4, 3)) F(rand(T, 4, 3))
22+
test_frule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
23+
test_rrule(dot, rand(T, 3), A, rand(T, 4); rtol=1f-3)
4324
end
4425
end
45-
@testset "cross" begin
46-
@testset "frule" begin
47-
@testset "$T" for T in (Float64, ComplexF64)
48-
n = 3
49-
x, y = randn(T, n), randn(T, n)
50-
ẋ, ẏ = randn(T, n), randn(T, n)
51-
frule_test(cross, (x, ẋ), (y, ẏ))
52-
end
53-
end
54-
@testset "rrule" begin
55-
n = 3
56-
x, y = randn(n), randn(n)
57-
x̄, ȳ = randn(n), randn(n)
58-
ΔΩ = randn(n)
59-
rrule_test(cross, ΔΩ, (x, x̄), (y, ȳ))
60-
end
26+
27+
@testset "cross"
28+
test_frule(cross, randn(3), randn(3))
29+
test_frule(cross, randn(ComplexF64, 3), randn(ComplexF64, 3))
30+
test_rrule(cross, randn(3), randn(3))
31+
# No complex support for rrule(cross,...
6132
end
6233
@testset "pinv" begin
6334
@testset "$T" for T in (Float64, ComplexF64)
@@ -66,37 +37,38 @@
6637
@test rrule(pinv, zero(T))[2](randn(T))[2] zero(T)
6738
end
6839
@testset "Vector{$T}" for T in (Float64, ComplexF64)
69-
n = 3
70-
x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
71-
tol, ṫol, t̄ol = 0.0, randn(), randn()
72-
Δy = copyto!(similar(pinv(x)), randn(T, n))
73-
frule_test(pinv, (x, ẋ), (tol, ṫol))
40+
test_frule(pinv, randn(T, 3), 0.0)
41+
test_frule(pinv, randn(T, 3), 0.0)
42+
43+
# Checking types. TODO do we still need this?
44+
x = randn(T, 3)
45+
= randn(T, 3)
46+
Δy = copyto!(similar(pinv(x)), randn(T, 3))
7447
@test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
75-
rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol))
7648
@test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
7749
end
50+
#TODO Everything after this point
7851
@testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
79-
n = 3
80-
x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n))
52+
x, ẋ, x̄ = F(randn(T, 3)), F(randn(T, 3)), F(randn(T, 3))
8153
y = pinv(x)
82-
Δy = copyto!(similar(y), randn(T, n))
83-
frule_test(pinv, (x, ẋ))
54+
Δy = copyto!(similar(y), randn(T, 3))
55+
test_frule(pinv, (x, ẋ))
8456
y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
8557
@test y_fwd isa typeof(y)
8658
@test ∂y_fwd isa typeof(y)
87-
rrule_test(pinv, Δy, (x, x̄))
59+
test_rrule(pinv, Δy, (x, x̄))
8860
y_rev, back = rrule(pinv, x)
8961
@test y_rev isa typeof(y)
9062
@test back(Δy)[2] isa typeof(x)
9163
end
92-
@testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64),
64+
@testset "Matrix{$T} with size ($m,$3)" for T in (Float64, ComplexF64),
9365
m in 1:3,
94-
n in 1:3
66+
3 in 1:3
9567

96-
X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n)
68+
X, Ẋ, X̄ = randn(T, m, 3), randn(T, m, 3), randn(T, m, 3)
9769
ΔY = randn(T, size(pinv(X))...)
98-
frule_test(pinv, (X, Ẋ))
99-
rrule_test(pinv, ΔY, (X, X̄))
70+
test_frule(pinv, (X, Ẋ))
71+
test_rrule(pinv, ΔY, (X, X̄))
10072
end
10173
end
10274
@testset "$f" for f in (det, logdet)
@@ -110,24 +82,22 @@
11082
else
11183
kwargs = NamedTuple()
11284
end
113-
N = 3
114-
B = generate_well_conditioned_matrix(T, N)
115-
frule_test(f, (B, randn(T, N, N)); kwargs...)
116-
rrule_test(f, randn(T), (B, randn(T, N, N)); kwargs...)
85+
B = generate_well_conditioned_matrix(T, 4)
86+
test_frule(f, (B, randn(T, 4, 4)); kwargs...)
87+
test_rrule(f, randn(T), (B, randn(T, 4, 4)); kwargs...)
11788
end
11889
end
11990
@testset "logabsdet(::Matrix{$T})" for T in (Float64, ComplexF64)
120-
N = 3
121-
B = randn(T, N, N)
122-
frule_test(logabsdet, (B, randn(T, N, N)))
123-
rrule_test(logabsdet, (randn(), randn(T)), (B, randn(T, N, N)))
91+
B = randn(T, 4, 4)
92+
test_frule(logabsdet, (B, randn(T, 4, 4)))
93+
test_rrule(logabsdet, (randn(), randn(T)), (B, randn(T, 4, 4)))
12494
# test for opposite sign of determinant
125-
frule_test(logabsdet, (-B, randn(T, N, N)))
126-
rrule_test(logabsdet, (randn(), randn(T)), (-B, randn(T, N, N)))
95+
test_frule(logabsdet, (-B, randn(T, 4, 4)))
96+
test_rrule(logabsdet, (randn(), randn(T)), (-B, randn(T, 4, 4)))
12797
end
12898
@testset "tr" begin
129-
N = 4
130-
frule_test(tr, (randn(N, N), randn(N, N)))
131-
rrule_test(tr, randn(), (randn(N, N), randn(N, N)))
99+
test_frule(tr, (randn(4, 4), randn(4, 4)))
100+
test_rrule(tr, randn(), (randn(4, 4), randn(4, 4)))
132101
end
102+
==#
133103
end

0 commit comments

Comments
 (0)