Skip to content

Commit 262b40a

Browse files
authored
Fix (l/r)mul! with Diagonal/Bidiagonal (#55052)
Currently, `rmul!(A::AbstractMatirx, D::Diagonal)` calls `mul!(A, A, D)`, but this isn't a valid call, as `mul!` assumes no aliasing between the destination and the matrices to be multiplied. As a consequence, ```julia julia> B = Bidiagonal(rand(4), rand(3), :L) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.476892 ⋅ ⋅ ⋅ 0.353756 0.139188 ⋅ ⋅ ⋅ 0.685839 0.309336 ⋅ ⋅ ⋅ 0.369038 0.304273 julia> D = Diagonal(rand(size(B,2))); julia> rmul!(B, D) 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 julia> B 4×4 Bidiagonal{Float64, Vector{Float64}}: 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ⋅ ⋅ ⋅ 0.0 0.0 ``` This is clearly nonsense, and happens because the internal `_mul!` function assumes that it can safely overwrite the destination with zeros before carrying out the multiplication. This is fixed in this PR by using broadcasting instead. The current implementation is generally equally performant, albeit occasionally with a minor allocation arising from `reshape`ing an `Array`. A similar problem also exists in `l/rmul!` with `Bidiaognal`, but that's a little harder to fix while remaining equally performant.
1 parent faf17eb commit 262b40a

File tree

6 files changed

+183
-4
lines changed

6 files changed

+183
-4
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
470470
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
471471
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))
472472

473-
lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
474-
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
473+
# B .= A * B
474+
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
475+
_muldiag_size_check(A, B)
476+
(; dv, ev) = A
477+
if A.uplo == 'U'
478+
for k in axes(B,2)
479+
for i in axes(ev,1)
480+
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
481+
end
482+
B[end,k] = dv[end] * B[end,k]
483+
end
484+
else
485+
for k in axes(B,2)
486+
for i in reverse(axes(dv,1)[2:end])
487+
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
488+
end
489+
B[1,k] = dv[1] * B[1,k]
490+
end
491+
end
492+
return B
493+
end
494+
# B .= D * B
495+
function lmul!(D::Diagonal, B::Bidiagonal)
496+
_muldiag_size_check(D, B)
497+
(; dv, ev) = B
498+
isL = B.uplo == 'L'
499+
dv[1] = D.diag[1] * dv[1]
500+
for i in axes(ev,1)
501+
ev[i] = D.diag[i + isL] * ev[i]
502+
dv[i+1] = D.diag[i+1] * dv[i+1]
503+
end
504+
return B
505+
end
506+
# B .= B * A
507+
function rmul!(B::AbstractMatrix, A::Bidiagonal)
508+
_muldiag_size_check(A, B)
509+
(; dv, ev) = A
510+
if A.uplo == 'U'
511+
for k in reverse(axes(dv,1)[2:end])
512+
for i in axes(B,1)
513+
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
514+
end
515+
end
516+
for i in axes(B,1)
517+
B[i,1] *= dv[1]
518+
end
519+
else
520+
for k in axes(ev,1)
521+
for i in axes(B,1)
522+
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
523+
end
524+
end
525+
for i in axes(B,1)
526+
B[i,end] *= dv[end]
527+
end
528+
end
529+
return B
530+
end
531+
# B .= B * D
532+
function rmul!(B::Bidiagonal, D::Diagonal)
533+
_muldiag_size_check(B, D)
534+
(; dv, ev) = B
535+
isU = B.uplo == 'U'
536+
dv[1] *= D.diag[1]
537+
for i in axes(ev,1)
538+
ev[i] *= D.diag[i + isU]
539+
dv[i+1] *= D.diag[i+1]
540+
end
541+
return B
542+
end
475543

476544
function check_A_mul_B!_sizes(C, A, B)
477545
mA, nA = size(A)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,49 @@ function (*)(D::Diagonal, V::AbstractVector)
327327
return D.diag .* V
328328
end
329329

330-
rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
331-
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
330+
function rmul!(A::AbstractMatrix, D::Diagonal)
331+
_muldiag_size_check(A, D)
332+
for I in CartesianIndices(A)
333+
row, col = Tuple(I)
334+
@inbounds A[row, col] *= D.diag[col]
335+
end
336+
return A
337+
end
338+
# T .= T * D
339+
function rmul!(T::Tridiagonal, D::Diagonal)
340+
_muldiag_size_check(T, D)
341+
(; dl, d, du) = T
342+
d[1] *= D.diag[1]
343+
for i in axes(dl,1)
344+
dl[i] *= D.diag[i]
345+
du[i] *= D.diag[i+1]
346+
d[i+1] *= D.diag[i+1]
347+
end
348+
return T
349+
end
350+
351+
function lmul!(D::Diagonal, B::AbstractVecOrMat)
352+
_muldiag_size_check(D, B)
353+
for I in CartesianIndices(B)
354+
row = I[1]
355+
@inbounds B[I] = D.diag[row] * B[I]
356+
end
357+
return B
358+
end
359+
360+
# in-place multiplication with a diagonal
361+
# T .= D * T
362+
function lmul!(D::Diagonal, T::Tridiagonal)
363+
_muldiag_size_check(D, T)
364+
(; dl, d, du) = T
365+
d[1] = D.diag[1] * d[1]
366+
for i in axes(dl,1)
367+
dl[i] = D.diag[i+1] * dl[i]
368+
du[i] = D.diag[i] * du[i]
369+
d[i+1] = D.diag[i+1] * d[i+1]
370+
end
371+
return T
372+
end
332373

333374
function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
334375
require_one_based_indexing(out, B)

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,41 @@ end
933933
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
934934
end
935935

936+
@testset "rmul!/lmul! with banded matrices" begin
937+
dv, ev = rand(4), rand(3)
938+
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
939+
@testset "$(nameof(typeof(B)))" for B in (
940+
Bidiagonal(dv, ev, :U),
941+
Bidiagonal(dv, ev, :L),
942+
Diagonal(dv)
943+
)
944+
@test_throws ArgumentError rmul!(B, A)
945+
@test_throws ArgumentError lmul!(A, B)
946+
end
947+
D = Diagonal(dv)
948+
@test rmul!(copy(A), D) A * D
949+
@test lmul!(D, copy(A)) D * A
950+
end
951+
@testset "non-commutative" begin
952+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
953+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
954+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
955+
for uplo in (:L, :U)
956+
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
957+
D = Diagonal(fill(S22, size(B,2)))
958+
@test rmul!(copy(B), D) B * D
959+
D = Diagonal(fill(S33, size(B,1)))
960+
@test lmul!(D, copy(B)) D * B
961+
end
962+
963+
B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
964+
D = Diagonal(fill(S32, 4))
965+
@test lmul!(B, Array(D)) B * D
966+
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
967+
@test rmul!(Array(D), B) D * B
968+
end
969+
end
970+
936971
@testset "conversion to Tridiagonal for immutable bands" begin
937972
n = 4
938973
dv = FillArrays.Fill(3, n)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,4 +1322,17 @@ end
13221322
@test M == D
13231323
end
13241324

1325+
@testset "rmul!/lmul! with banded matrices" begin
1326+
@testset "$(nameof(typeof(B)))" for B in (
1327+
Bidiagonal(rand(4), rand(3), :L),
1328+
Tridiagonal(rand(3), rand(4), rand(3))
1329+
)
1330+
BA = Array(B)
1331+
D = Diagonal(rand(size(B,1)))
1332+
DA = Array(D)
1333+
@test rmul!(copy(B), D) B * D BA * DA
1334+
@test lmul!(D, copy(B)) D * B DA * BA
1335+
end
1336+
end
1337+
13251338
end # module TestDiagonal

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,4 +892,23 @@ end
892892
end
893893
end
894894

895+
@testset "rmul!/lmul! with banded matrices" begin
896+
dl, d, du = rand(3), rand(4), rand(3)
897+
A = Tridiagonal(dl, d, du)
898+
D = Diagonal(d)
899+
@test rmul!(copy(A), D) A * D
900+
@test lmul!(D, copy(A)) D * A
901+
902+
@testset "non-commutative" begin
903+
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
904+
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
905+
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
906+
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
907+
D = Diagonal(fill(S22, size(T,2)))
908+
@test rmul!(copy(T), D) T * D
909+
D = Diagonal(fill(S33, size(T,1)))
910+
@test lmul!(D, copy(T)) D * T
911+
end
912+
end
913+
895914
end # module TestTridiagonal

test/testhelpers/SizedArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
2323
Base.last(r::SOneTo) = length(r)
2424
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")
2525

26+
Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
27+
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s
28+
2629
struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
2730
data::A
2831
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}

0 commit comments

Comments
 (0)