Skip to content

Commit 2e169af

Browse files
committed
Reduce allocations in diagonal tests
1 parent b265fea commit 2e169af

File tree

1 file changed

+96
-76
lines changed

1 file changed

+96
-76
lines changed

test/diagonal.jl

Lines changed: 96 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
178178
@test D*v DM*v atol=n*eps(relty)*(1+(elty<:Complex))
179179
@test D*U DM*U atol=n^2*eps(relty)*(1+(elty<:Complex))
180180

181-
@test transpose(U)*D transpose(U)*Array(D)
182-
@test U'*D U'*Array(D)
181+
@test transpose(U)*D transpose(U)*M
182+
@test U'*D U'*M
183183

184184
if relty != BigFloat
185185
atol_two = 2n^2 * eps(relty) * (1 + (elty <: Complex))
@@ -214,12 +214,12 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
214214
@test_throws DimensionMismatch ldiv!(D, fill(elty(1), n + 1))
215215
@test_throws SingularException ldiv!(Diagonal(zeros(relty, n)), copy(v))
216216
b = rand(elty, n, n)
217-
@test ldiv!(D, copy(b)) Array(D)\Array(b)
217+
@test ldiv!(D, copy(b)) M\b
218218
@test_throws SingularException ldiv!(Diagonal(zeros(elty, n)), copy(b))
219219
b = view(rand(elty, n), Vector(1:n))
220220
b2 = copy(b)
221221
c = ldiv!(D, b)
222-
d = Array(D)\b2
222+
d = M\b2
223223
@test c d
224224
@test_throws SingularException ldiv!(Diagonal(zeros(elty, n)), b)
225225
b = rand(elty, n+1, n+1)
@@ -256,17 +256,17 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
256256

257257
if relty <: BlasFloat
258258
for b in (rand(elty,n,n), rand(elty,n))
259-
@test lmul!(copy(D), copy(b)) Array(D)*Array(b)
260-
@test lmul!(transpose(copy(D)), copy(b)) transpose(Array(D))*Array(b)
261-
@test lmul!(adjoint(copy(D)), copy(b)) Array(D)'*Array(b)
259+
@test lmul!(copy(D), copy(b)) M*b
260+
@test lmul!(transpose(copy(D)), copy(b)) transpose(M)*b
261+
@test lmul!(adjoint(copy(D)), copy(b)) M'*b
262262
end
263263
end
264264

265265
#a few missing mults
266266
bd = Bidiagonal(D2)
267-
@test D*transpose(D2) Array(D)*transpose(Array(D2))
268-
@test D2*transpose(D) Array(D2)*transpose(Array(D))
269-
@test D2*D' Array(D2)*Array(D)'
267+
@test D*transpose(D2) M*transpose(DM2)
268+
@test D2*transpose(D) DM2*transpose(M)
269+
@test D2*D' DM2*M'
270270

271271
#division of two Diagonals
272272
@test D/D2 Diagonal(D.diag./D2.diag)
@@ -281,33 +281,36 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
281281
A = rand(elty, n, n)
282282
Asym = Symmetric(A + transpose(A), :U)
283283
Aherm = Hermitian(A + adjoint(A), :U)
284+
Msym = Array(Asym)
285+
Mherm = Array(Aherm)
284286
for op in (+, -)
285287
@test op(Asym, D) isa Symmetric
286-
@test Array(op(Asym, D)) Array(Symmetric(op(Array(Asym), Array(D))))
288+
@test Array(op(Asym, D)) Array(Symmetric(op(Msym, M)))
287289
@test op(D, Asym) isa Symmetric
288-
@test Array(op(D, Asym)) Array(Symmetric(op(Array(D), Array(Asym))))
290+
@test Array(op(D, Asym)) Array(Symmetric(op(M, Msym)))
289291
if !(elty <: Real)
290292
Dr = real(D)
293+
Mr = Array(Dr)
291294
@test op(Aherm, Dr) isa Hermitian
292-
@test Array(op(Aherm, Dr)) Array(Hermitian(op(Array(Aherm), Array(Dr))))
295+
@test Array(op(Aherm, Dr)) Array(Hermitian(op(Mherm, Mr)))
293296
@test op(Dr, Aherm) isa Hermitian
294-
@test Array(op(Dr, Aherm)) Array(Hermitian(op(Array(Dr), Array(Aherm))))
297+
@test Array(op(Dr, Aherm)) Array(Hermitian(op(Mr, Mherm)))
295298
end
296299
end
297-
@test Array(D*transpose(Asym)) Array(D) * Array(transpose(Asym))
298-
@test Array(D*adjoint(Asym)) Array(D) * Array(adjoint(Asym))
299-
@test Array(D*transpose(Aherm)) Array(D) * Array(transpose(Aherm))
300-
@test Array(D*adjoint(Aherm)) Array(D) * Array(adjoint(Aherm))
300+
@test Array(D*transpose(Asym)) M * Array(transpose(Asym))
301+
@test Array(D*adjoint(Asym)) M * Array(adjoint(Asym))
302+
@test Array(D*transpose(Aherm)) M * Array(transpose(Aherm))
303+
@test Array(D*adjoint(Aherm)) M * Array(adjoint(Aherm))
301304
@test Array(transpose(Asym)*transpose(D)) Array(transpose(Asym)) * Array(transpose(D))
302305
@test Array(transpose(D)*transpose(Asym)) Array(transpose(D)) * Array(transpose(Asym))
303306
@test Array(adjoint(Aherm)*adjoint(D)) Array(adjoint(Aherm)) * Array(adjoint(D))
304307
@test Array(adjoint(D)*adjoint(Aherm)) Array(adjoint(D)) * Array(adjoint(Aherm))
305308

306309
# Performance specialisations for A*_mul_B!
307310
vvv = similar(vv)
308-
@test (r = Matrix(D) * vv ; mul!(vvv, D, vv) r vvv)
309-
@test (r = Matrix(D)' * vv ; mul!(vvv, adjoint(D), vv) r vvv)
310-
@test (r = transpose(Matrix(D)) * vv ; mul!(vvv, transpose(D), vv) r vvv)
311+
@test (r = M * vv ; mul!(vvv, D, vv) r vvv)
312+
@test (r = M' * vv ; mul!(vvv, adjoint(D), vv) r vvv)
313+
@test (r = transpose(M) * vv ; mul!(vvv, transpose(D), vv) r vvv)
311314

312315
UUU = similar(UU)
313316
for transformA in (identity, adjoint, transpose)
@@ -319,55 +322,62 @@ LinearAlgebra.istril(N::NotDiagonal) = istril(N.a)
319322

320323
alpha = elty(randn()) # randn(elty) does not work with BigFloat
321324
beta = elty(randn())
322-
@test begin
325+
@testset begin
323326
vvv = similar(vv)
324327
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
325-
r = alpha * Matrix(D) * vv + beta * vvv
326-
mul!(vvv, D, vv, alpha, beta) r vvv
328+
r = alpha * M * vv + beta * vvv
329+
@test mul!(vvv, D, vv, alpha, beta) === vvv
330+
@test r vvv
327331
end
328-
@test begin
332+
@testset begin
329333
vvv = similar(vv)
330334
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
331-
r = alpha * Matrix(D)' * vv + beta * vvv
332-
mul!(vvv, adjoint(D), vv, alpha, beta) r vvv
335+
r = alpha * M' * vv + beta * vvv
336+
@test mul!(vvv, adjoint(D), vv, alpha, beta) === vvv
337+
@test r vvv
333338
end
334-
@test begin
339+
@testset begin
335340
vvv = similar(vv)
336341
vvv .= randn(size(vvv)) # randn!(vvv) does not work with BigFloat
337-
r = alpha * transpose(Matrix(D)) * vv + beta * vvv
338-
mul!(vvv, transpose(D), vv, alpha, beta) r vvv
342+
r = alpha * transpose(M) * vv + beta * vvv
343+
@test mul!(vvv, transpose(D), vv, alpha, beta) === vvv
344+
@test r vvv
339345
end
340346

341-
@test begin
347+
@testset begin
342348
UUU = similar(UU)
343349
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
344-
r = alpha * Matrix(D) * UU + beta * UUU
345-
mul!(UUU, D, UU, alpha, beta) r UUU
350+
r = alpha * M * UU + beta * UUU
351+
@test mul!(UUU, D, UU, alpha, beta) === UUU
352+
@test r UUU
346353
end
347-
@test begin
354+
@testset begin
348355
UUU = similar(UU)
349356
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
350-
r = alpha * Matrix(D)' * UU + beta * UUU
351-
mul!(UUU, adjoint(D), UU, alpha, beta) r UUU
357+
r = alpha * M' * UU + beta * UUU
358+
@test mul!(UUU, adjoint(D), UU, alpha, beta) === UUU
359+
@test r UUU
352360
end
353-
@test begin
361+
@testset begin
354362
UUU = similar(UU)
355363
UUU .= randn(size(UUU)) # randn!(UUU) does not work with BigFloat
356-
r = alpha * transpose(Matrix(D)) * UU + beta * UUU
357-
mul!(UUU, transpose(D), UU, alpha, beta) r UUU
364+
r = alpha * transpose(M) * UU + beta * UUU
365+
@test mul!(UUU, transpose(D), UU, alpha, beta) === UUU
366+
@test r UUU
358367
end
359368

360369
# make sure that mul!(A, {Adj|Trans}(B)) works with B as a Diagonal
361370
VV = Array(D)
362-
DD = copy(D)
363-
r = VV * Matrix(D)
364-
@test Array(rmul!(VV, DD)) r Array(D)*Array(D)
365-
DD = copy(D)
366-
r = VV * transpose(Array(D))
367-
@test Array(rmul!(VV, transpose(DD))) r
368-
DD = copy(D)
369-
r = VV * Array(D)'
370-
@test Array(rmul!(VV, adjoint(DD))) r
371+
r = VV * M
372+
@test rmul!(VV, D) r M*M
373+
if transpose(D) !== D
374+
r = VV * transpose(M)
375+
@test rmul!(VV, transpose(D)) r
376+
end
377+
if adjoint(D) !== D
378+
r = VV * M'
379+
@test rmul!(VV, adjoint(D)) r
380+
end
371381

372382
# kron
373383
D3 = Diagonal(convert(Vector{elty}, rand(n÷2)))
@@ -545,16 +555,17 @@ Base.size(x::SimpleVector) = size(x.vec)
545555

546556
@testset "kron (issue #46456)" for repr in Any[identity, SimpleVector]
547557
A = Diagonal(repr(randn(10)))
558+
M = Array(A)
548559
BL = Bidiagonal(repr(randn(10)), repr(randn(9)), :L)
549560
BU = Bidiagonal(repr(randn(10)), repr(randn(9)), :U)
550561
C = SymTridiagonal(repr(randn(10)), repr(randn(9)))
551562
Cl = SymTridiagonal(repr(randn(10)), repr(randn(10)))
552563
D = Tridiagonal(repr(randn(9)), repr(randn(10)), repr(randn(9)))
553-
@test kron(A, BL)::Bidiagonal == kron(Array(A), Array(BL))
554-
@test kron(A, BU)::Bidiagonal == kron(Array(A), Array(BU))
555-
@test kron(A, C)::SymTridiagonal == kron(Array(A), Array(C))
556-
@test kron(A, Cl)::SymTridiagonal == kron(Array(A), Array(Cl))
557-
@test kron(A, D)::Tridiagonal == kron(Array(A), Array(D))
564+
@test kron(A, BL)::Bidiagonal == kron(M, Array(BL))
565+
@test kron(A, BU)::Bidiagonal == kron(M, Array(BU))
566+
@test kron(A, C)::SymTridiagonal == kron(M, Array(C))
567+
@test kron(A, Cl)::SymTridiagonal == kron(M, Array(Cl))
568+
@test kron(A, D)::Tridiagonal == kron(M, Array(D))
558569
end
559570

560571
@testset "svdvals and eigvals (#11120/#11247)" begin
@@ -627,9 +638,10 @@ end
627638

628639
@testset "Test reverse" begin
629640
D = Diagonal(randn(5))
630-
@test reverse(D, dims=1) == reverse(Matrix(D), dims=1)
631-
@test reverse(D, dims=2) == reverse(Matrix(D), dims=2)
632-
@test reverse(D)::Diagonal == reverse(Matrix(D))
641+
M = Matrix(D)
642+
@test reverse(D, dims=1) == reverse(M, dims=1)
643+
@test reverse(D, dims=2) == reverse(M, dims=2)
644+
@test reverse(D)::Diagonal == reverse(M)
633645
end
634646

635647
@testset "inverse" begin
@@ -645,8 +657,9 @@ end
645657
@testset "pseudoinverse" begin
646658
for d in Any[randn(n), zeros(n), Int[], [0, 2, 0.003], [0im, 1+2im, 0.003im], [0//1, 2//1, 3//100], [0//1, 1//1+2im, 3im//100]]
647659
D = Diagonal(d)
648-
@test pinv(D) pinv(Array(D))
649-
@test pinv(D, 1.0e-2) pinv(Array(D), 1.0e-2)
660+
M = Array(D)
661+
@test pinv(D) pinv(M)
662+
@test pinv(D, 1.0e-2) pinv(M, 1.0e-2)
650663
end
651664
end
652665

@@ -662,51 +675,54 @@ end
662675
@test Matrix(1.0I, 5, 5) \ Diagonal(fill(1.,5)) == Matrix(I, 5, 5)
663676

664677
@testset "Triangular and Diagonal" begin
665-
function _test_matrix(type)
678+
function _randomarray(type, ::Val{N} = Val(2)) where {N}
679+
sz = ntuple(_->5, N)
666680
if type == Int
667-
return rand(1:9, 5, 5)
681+
return rand(1:9, sz...)
668682
else
669-
return randn(type, 5, 5)
683+
return randn(type, sz...)
670684
end
671685
end
672686
types = (Float64, Int, ComplexF64)
673687
for ta in types
674-
D = Diagonal(_test_matrix(ta))
688+
D = Diagonal(_randomarray(ta, Val(1)))
689+
M = Matrix(D)
675690
for tb in types
676-
B = _test_matrix(tb)
691+
B = _randomarray(tb, Val(2))
677692
Tmats = (LowerTriangular(B), UnitLowerTriangular(B), UpperTriangular(B), UnitUpperTriangular(B))
678693
restypes = (LowerTriangular, LowerTriangular, UpperTriangular, UpperTriangular)
679694
for (T, rtype) in zip(Tmats, restypes)
680695
adjtype = (rtype == LowerTriangular) ? UpperTriangular : LowerTriangular
681696

682697
# Triangular * Diagonal
683698
R = T * D
684-
@test R Array(T) * Array(D)
699+
TA = Array(T)
700+
@test R TA * M
685701
@test isa(R, rtype)
686702

687703
# Diagonal * Triangular
688704
R = D * T
689-
@test R Array(D) * Array(T)
705+
@test R M * TA
690706
@test isa(R, rtype)
691707

692708
# Adjoint of Triangular * Diagonal
693709
R = T' * D
694-
@test R Array(T)' * Array(D)
710+
@test R TA' * M
695711
@test isa(R, adjtype)
696712

697713
# Diagonal * Adjoint of Triangular
698714
R = D * T'
699-
@test R Array(D) * Array(T)'
715+
@test R M * TA'
700716
@test isa(R, adjtype)
701717

702718
# Transpose of Triangular * Diagonal
703719
R = transpose(T) * D
704-
@test R transpose(Array(T)) * Array(D)
720+
@test R transpose(TA) * M
705721
@test isa(R, adjtype)
706722

707723
# Diagonal * Transpose of Triangular
708724
R = D * transpose(T)
709-
@test R Array(D) * transpose(Array(T))
725+
@test R M * transpose(TA)
710726
@test isa(R, adjtype)
711727
end
712728
end
@@ -1333,7 +1349,7 @@ end
13331349
end
13341350

13351351
@testset "diagonal triple multiplication (#49005)" begin
1336-
n = 10
1352+
local n = 10
13371353
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal
13381354
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1))))
13391355
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1))))
@@ -1449,10 +1465,12 @@ end
14491465
for p in ([1 2; 3 4], [1 2+im; 2-im 4+2im])
14501466
m = SizedArrays.SizedArray{(2,2)}(p)
14511467
D = Diagonal(fill(m, 2))
1468+
M = Matrix(D)
14521469
for T in (Symmetric, Hermitian)
14531470
S = T(fill(m, 2, 2))
1454-
@test D + S == Array(D) + Array(S)
1455-
@test S + D == Array(S) + Array(D)
1471+
SA = Array(S)
1472+
@test D + S == M + SA
1473+
@test S + D == SA + M
14561474
end
14571475
end
14581476
end
@@ -1464,12 +1482,14 @@ end
14641482

14651483
@testset "zeros in kron with block matrices" begin
14661484
D = Diagonal(1:4)
1485+
M = Matrix(D)
14671486
B = reshape([ones(2,2), ones(3,2), ones(2,3), ones(3,3)], 2, 2)
1468-
@test kron(D, B) == kron(Array(D), B)
1469-
@test kron(B, D) == kron(B, Array(D))
1487+
@test kron(D, B) == kron(M, B)
1488+
@test kron(B, D) == kron(B, M)
14701489
D2 = Diagonal([ones(2,2), ones(3,3)])
1471-
@test kron(D, D2) == kron(D, Array{eltype(D2)}(D2))
1472-
@test kron(D2, D) == kron(Array{eltype(D2)}(D2), D)
1490+
M2 = Array{eltype(D2)}(D2)
1491+
@test kron(D, D2) == kron(D, M2)
1492+
@test kron(D2, D) == kron(M2, D)
14731493
end
14741494

14751495
@testset "opnorms" begin

0 commit comments

Comments
 (0)