|
2 | 2 | @testset "dot" begin
|
3 | 3 | @testset "all entries" begin
|
4 | 4 | n = 10
|
5 |
| - x, y = randn(n), randn(n) |
6 |
| - ẋ, ẏ = randn(n), randn(n) |
7 |
| - x̄, ȳ = randn(n), randn(n) |
8 |
| - frule_test(BLAS.dot, (x, ẋ), (y, ẏ)) |
9 |
| - rrule_test(BLAS.dot, randn(), (x, x̄), (y, ȳ)) |
| 5 | + test_frule(BLAS.dot, randn(n), randn(n)) |
| 6 | + test_rrule(BLAS.dot, randn(n), randn(n)) |
10 | 7 | end
|
11 | 8 |
|
12 | 9 | @testset "over strides" begin
|
13 | 10 | n = 10
|
14 | 11 | incx = 2
|
15 | 12 | incy = 3
|
16 |
| - x, y = randn(n * incx), randn(n * incy) |
17 |
| - x̄, ȳ = randn(n * incx), randn(n * incy) |
18 |
| - rrule_test( |
| 13 | + test_rrule( |
19 | 14 | BLAS.dot,
|
20 |
| - randn(), |
21 |
| - (n, nothing), |
22 |
| - (x, x̄), |
23 |
| - (incx, nothing), |
24 |
| - (y, ȳ), |
25 |
| - (incy, nothing), |
| 15 | + n ⊢ nothing, |
| 16 | + randn(n * incx), |
| 17 | + incx ⊢ nothing, |
| 18 | + randn(n * incy), |
| 19 | + incy ⊢ nothing, |
26 | 20 | )
|
27 | 21 | end
|
28 | 22 | end
|
29 | 23 |
|
30 | 24 | @testset "nrm2" begin
|
31 | 25 | @testset "all entries" begin
|
32 |
| - @testset "$T" for T in (Float64,ComplexF64) |
| 26 | + @testset "$T" for T in (Float64, ComplexF64) |
33 | 27 | n = 10
|
34 |
| - x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n) |
35 |
| - frule_test(BLAS.nrm2, (x, ẋ)) |
36 |
| - rrule_test(BLAS.nrm2, randn(), (x, x̄)) |
| 28 | + test_frule(BLAS.nrm2, randn(T, n)) |
| 29 | + test_rrule(BLAS.nrm2, randn(T, n)) |
37 | 30 | end
|
38 | 31 | end
|
39 | 32 |
|
40 | 33 | @testset "over strides" begin
|
41 | 34 | dims = (3, 2, 1)
|
42 | 35 | incx = 2
|
43 |
| - @testset "Array{$T,$N}" for N in 1:length(dims), T in (Float64,ComplexF64) |
| 36 | + @testset "Array{$T,$N}" for N in eachindex(dims), T in (Float64, ComplexF64) |
44 | 37 | s = (dims[1] * incx, dims[2:N]...)
|
45 | 38 | n = div(prod(s), incx)
|
46 |
| - x, x̄ = randn(T, s), randn(T, s) |
47 |
| - rrule_test( |
48 |
| - BLAS.nrm2, |
49 |
| - randn(), |
50 |
| - (n, nothing), |
51 |
| - (x, x̄), |
52 |
| - (incx, nothing); |
53 |
| - atol=0, |
54 |
| - rtol=1e-5, |
| 39 | + test_rrule( |
| 40 | + BLAS.nrm2, n ⊢ nothing, randn(T, s), incx ⊢ nothing; atol=0, rtol=1e-5, |
55 | 41 | )
|
56 | 42 | end
|
57 | 43 | end
|
58 | 44 | end
|
59 | 45 |
|
60 | 46 | @testset "asum" begin
|
61 | 47 | @testset "all entries" begin
|
62 |
| - @testset "$T" for T in (Float64,ComplexF64) |
| 48 | + @testset "$T" for T in (Float64, ComplexF64) |
63 | 49 | n = 6
|
64 |
| - x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n) |
65 |
| - frule_test(BLAS.asum, (x, ẋ)) |
66 |
| - rrule_test(BLAS.asum, randn(), (x, x̄)) |
| 50 | + test_frule(BLAS.asum, randn(T, n)) |
| 51 | + test_rrule(BLAS.asum, randn(T, n)) |
67 | 52 | end
|
68 | 53 | end
|
69 | 54 |
|
70 | 55 | @testset "over strides" begin
|
71 | 56 | dims = (2, 2, 1)
|
72 | 57 | incx = 2
|
73 |
| - @testset "Array{$T,$N}" for N in 1:length(dims), T in (Float64,ComplexF64) |
| 58 | + @testset "Array{$T,$N}" for N in eachindex(dims), T in (Float64, ComplexF64) |
74 | 59 | s = (dims[1] * incx, dims[2:N]...)
|
75 | 60 | n = div(prod(s), incx)
|
76 |
| - x, x̄ = randn(T, s), randn(T, s) |
77 |
| - rrule_test( BLAS.asum, randn(), (n, nothing), (x, x̄), (incx, nothing)) |
| 61 | + test_rrule( BLAS.asum, n ⊢ nothing, randn(T, s), incx ⊢ nothing) |
78 | 62 | end
|
79 | 63 | end
|
80 | 64 | end
|
81 | 65 |
|
82 | 66 | @testset "gemm" begin
|
83 |
| - dims = 3:5 |
84 |
| - for m in dims, n in dims, p in dims, tA in ('N', 'C', 'T'), tB in ('N', 'C', 'T'), T in (Float64, ComplexF64) |
85 |
| - α = randn(T) |
| 67 | + for m in 3:5, n in 3:5, p in 3:5, tA in ('N', 'C', 'T'), tB in ('N', 'C', 'T'), T in (Float64, ComplexF64) |
86 | 68 | A = randn(T, tA === 'N' ? (m, n) : (n, m))
|
87 | 69 | B = randn(T, tB === 'N' ? (n, p) : (p, n))
|
88 |
| - C = gemm(tA, tB, α, A, B) |
89 |
| - ȳ = randn(T, size(C)...) |
90 |
| - rrule_test( |
91 |
| - gemm, |
92 |
| - ȳ, |
93 |
| - (tA, nothing), |
94 |
| - (tB, nothing), |
95 |
| - (α, randn(T)), |
96 |
| - (A, randn(T, size(A))), |
97 |
| - (B, randn(T, size(B))); |
98 |
| - check_inferred=false, |
99 |
| - ) |
100 |
| - |
101 |
| - rrule_test( |
102 |
| - gemm, |
103 |
| - ȳ, |
104 |
| - (tA, nothing), |
105 |
| - (tB, nothing), |
106 |
| - (A, randn(T, size(A))), |
107 |
| - (B, randn(T, size(B))); |
108 |
| - check_inferred=false, |
| 70 | + test_rrule(gemm, tA ⊢ nothing, tB ⊢ nothing, A, B; check_inferred=false) |
| 71 | + test_rrule( # 5 arg version with scaling scalar |
| 72 | + gemm, tA ⊢ nothing, tB ⊢ nothing, randn(T), A, B; check_inferred=false, |
109 | 73 | )
|
110 | 74 | end
|
111 | 75 | end
|
112 | 76 |
|
113 | 77 | @testset "gemv" begin
|
114 | 78 | for n in 3:5, m in 3:5, t in ('N', 'C', 'T'), T in (Float64, ComplexF64)
|
115 |
| - α = randn(T) |
116 |
| - A = randn(T, m, n) |
117 | 79 | x = randn(T, t === 'N' ? n : m)
|
118 |
| - y = gemv(t, α, A, x) |
119 |
| - ȳ = randn(T, size(y)...) |
120 |
| - rrule_test( |
121 |
| - gemv, |
122 |
| - ȳ, |
123 |
| - (t, nothing), |
124 |
| - (α, randn(T)), |
125 |
| - (A, randn(T, size(A))), |
126 |
| - (x, randn(T, size(x))); |
127 |
| - check_inferred=false, |
128 |
| - ) |
| 80 | + test_rrule(gemv, t ⊢ nothing, randn(T), randn(T, m, n), x; check_inferred=false) |
129 | 81 | end
|
130 | 82 | end
|
131 | 83 | end
|
0 commit comments