Skip to content

Commit bd46c0a

Browse files
authored
Merge pull request #373 from JuliaDiff/ox/tanblas
blas.jl autotangent
2 parents dbc8134 + 2f749c4 commit bd46c0a

File tree

1 file changed

+24
-72
lines changed

1 file changed

+24
-72
lines changed

test/rulesets/LinearAlgebra/blas.jl

Lines changed: 24 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,130 +2,82 @@
22
@testset "dot" begin
33
@testset "all entries" begin
44
n = 10
5-
x, y = randn(n), randn(n)
6-
ẋ, ẏ = randn(n), randn(n)
7-
x̄, ȳ = randn(n), randn(n)
8-
frule_test(BLAS.dot, (x, ẋ), (y, ẏ))
9-
rrule_test(BLAS.dot, randn(), (x, x̄), (y, ȳ))
5+
test_frule(BLAS.dot, randn(n), randn(n))
6+
test_rrule(BLAS.dot, randn(n), randn(n))
107
end
118

129
@testset "over strides" begin
1310
n = 10
1411
incx = 2
1512
incy = 3
16-
x, y = randn(n * incx), randn(n * incy)
17-
x̄, ȳ = randn(n * incx), randn(n * incy)
18-
rrule_test(
13+
test_rrule(
1914
BLAS.dot,
20-
randn(),
21-
(n, nothing),
22-
(x, x̄),
23-
(incx, nothing),
24-
(y, ȳ),
25-
(incy, nothing),
15+
n nothing,
16+
randn(n * incx),
17+
incx nothing,
18+
randn(n * incy),
19+
incy nothing,
2620
)
2721
end
2822
end
2923

3024
@testset "nrm2" begin
3125
@testset "all entries" begin
32-
@testset "$T" for T in (Float64,ComplexF64)
26+
@testset "$T" for T in (Float64, ComplexF64)
3327
n = 10
34-
x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
35-
frule_test(BLAS.nrm2, (x, ẋ))
36-
rrule_test(BLAS.nrm2, randn(), (x, x̄))
28+
test_frule(BLAS.nrm2, randn(T, n))
29+
test_rrule(BLAS.nrm2, randn(T, n))
3730
end
3831
end
3932

4033
@testset "over strides" begin
4134
dims = (3, 2, 1)
4235
incx = 2
43-
@testset "Array{$T,$N}" for N in 1:length(dims), T in (Float64,ComplexF64)
36+
@testset "Array{$T,$N}" for N in eachindex(dims), T in (Float64, ComplexF64)
4437
s = (dims[1] * incx, dims[2:N]...)
4538
n = div(prod(s), incx)
46-
x, x̄ = randn(T, s), randn(T, s)
47-
rrule_test(
48-
BLAS.nrm2,
49-
randn(),
50-
(n, nothing),
51-
(x, x̄),
52-
(incx, nothing);
53-
atol=0,
54-
rtol=1e-5,
39+
test_rrule(
40+
BLAS.nrm2, n nothing, randn(T, s), incx nothing; atol=0, rtol=1e-5,
5541
)
5642
end
5743
end
5844
end
5945

6046
@testset "asum" begin
6147
@testset "all entries" begin
62-
@testset "$T" for T in (Float64,ComplexF64)
48+
@testset "$T" for T in (Float64, ComplexF64)
6349
n = 6
64-
x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
65-
frule_test(BLAS.asum, (x, ẋ))
66-
rrule_test(BLAS.asum, randn(), (x, x̄))
50+
test_frule(BLAS.asum, randn(T, n))
51+
test_rrule(BLAS.asum, randn(T, n))
6752
end
6853
end
6954

7055
@testset "over strides" begin
7156
dims = (2, 2, 1)
7257
incx = 2
73-
@testset "Array{$T,$N}" for N in 1:length(dims), T in (Float64,ComplexF64)
58+
@testset "Array{$T,$N}" for N in eachindex(dims), T in (Float64, ComplexF64)
7459
s = (dims[1] * incx, dims[2:N]...)
7560
n = div(prod(s), incx)
76-
x, x̄ = randn(T, s), randn(T, s)
77-
rrule_test( BLAS.asum, randn(), (n, nothing), (x, x̄), (incx, nothing))
61+
test_rrule( BLAS.asum, n nothing, randn(T, s), incx nothing)
7862
end
7963
end
8064
end
8165

8266
@testset "gemm" begin
83-
dims = 3:5
84-
for m in dims, n in dims, p in dims, tA in ('N', 'C', 'T'), tB in ('N', 'C', 'T'), T in (Float64, ComplexF64)
85-
α = randn(T)
67+
for m in 3:5, n in 3:5, p in 3:5, tA in ('N', 'C', 'T'), tB in ('N', 'C', 'T'), T in (Float64, ComplexF64)
8668
A = randn(T, tA === 'N' ? (m, n) : (n, m))
8769
B = randn(T, tB === 'N' ? (n, p) : (p, n))
88-
C = gemm(tA, tB, α, A, B)
89-
= randn(T, size(C)...)
90-
rrule_test(
91-
gemm,
92-
ȳ,
93-
(tA, nothing),
94-
(tB, nothing),
95-
(α, randn(T)),
96-
(A, randn(T, size(A))),
97-
(B, randn(T, size(B)));
98-
check_inferred=false,
99-
)
100-
101-
rrule_test(
102-
gemm,
103-
ȳ,
104-
(tA, nothing),
105-
(tB, nothing),
106-
(A, randn(T, size(A))),
107-
(B, randn(T, size(B)));
108-
check_inferred=false,
70+
test_rrule(gemm, tA nothing, tB nothing, A, B; check_inferred=false)
71+
test_rrule( # 5 arg version with scaling scalar
72+
gemm, tA nothing, tB nothing, randn(T), A, B; check_inferred=false,
10973
)
11074
end
11175
end
11276

11377
@testset "gemv" begin
11478
for n in 3:5, m in 3:5, t in ('N', 'C', 'T'), T in (Float64, ComplexF64)
115-
α = randn(T)
116-
A = randn(T, m, n)
11779
x = randn(T, t === 'N' ? n : m)
118-
y = gemv(t, α, A, x)
119-
= randn(T, size(y)...)
120-
rrule_test(
121-
gemv,
122-
ȳ,
123-
(t, nothing),
124-
(α, randn(T)),
125-
(A, randn(T, size(A))),
126-
(x, randn(T, size(x)));
127-
check_inferred=false,
128-
)
80+
test_rrule(gemv, t nothing, randn(T), randn(T, m, n), x; check_inferred=false)
12981
end
13082
end
13183
end

0 commit comments

Comments
 (0)