Skip to content

Commit 449a3e5

Browse files
authored
Handle more multiplications with AbstractQs (#117)
1 parent 679805a commit 449a3e5

File tree

4 files changed

+36
-11
lines changed

4 files changed

+36
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArrayLayouts"
22
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
33
authors = ["Sheehan Olver <[email protected]>"]
4-
version = "0.8.17"
4+
version = "0.8.18"
55

66
[deps]
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"

src/ArrayLayouts.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ using Base.Broadcast: Broadcasted
1515

1616
import Base.Broadcast: BroadcastStyle, broadcastable, instantiate, materialize, materialize!
1717

18-
using LinearAlgebra: AbstractTriangular, AbstractQ, QRCompactWYQ, QRPackedQ, checksquare,
19-
pinv, tilebufsize, cholcopy,
18+
using LinearAlgebra: AbstractQ, QRCompactWYQ, QRPackedQ, HessenbergQ,
19+
AbstractTriangular, checksquare, pinv, tilebufsize, cholcopy,
2020
norm2, norm1, normInf, normMinusInf,
2121
AdjOrTrans, HermOrSym, RealHermSymComplexHerm, AdjointAbsVec, TransposeAbsVec,
2222
checknonsingular, _apply_ipiv_rows!, ipiv2perm, chkfullrank

src/mul.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,19 +88,26 @@ check_mul_axes(A) = nothing
8888
_check_mul_axes(::Number, ::Number) = nothing
8989
_check_mul_axes(::Number, _) = nothing
9090
_check_mul_axes(_, ::Number) = nothing
91-
_check_mul_axes(A, B) = axes(A,2) == axes(B,1) || throw(DimensionMismatch("Second axis of A, $(axes(A,2)), and first axis of B, $(axes(B,1)) must match"))
91+
_check_mul_axes(A, B) = axes(A, 2) == axes(B, 1) || throw(DimensionMismatch("Second axis of A, $(axes(A,2)), and first axis of B, $(axes(B,1)) must match"))
92+
# we need to special case AbstractQ as it allows non-compatiple multiplication
93+
const FlexibleLeftQs = Union{QRCompactWYQ,QRPackedQ,HessenbergQ}
94+
_check_mul_axes(::FlexibleLeftQs, ::Number) = nothing
95+
_check_mul_axes(Q::FlexibleLeftQs, B) =
96+
axes(Q.factors, 1) == axes(B, 1) || axes(Q.factors, 2) == axes(B, 1) ||
97+
throw(DimensionMismatch("First axis of B, $(axes(B,1)) must match either axes of A, $(axes(Q.factors))"))
98+
_check_mul_axes(::Number, ::AdjointQtype{<:Any,<:FlexibleLeftQs}) = nothing
99+
function _check_mul_axes(A, adjQ::AdjointQtype{<:Any,<:FlexibleLeftQs})
100+
Q = parent(adjQ)
101+
axes(A, 2) == axes(Q.factors, 1) || axes(A, 2) == axes(Q.factors, 2) ||
102+
throw(DimensionMismatch("Second axis of A, $(axes(A,2)) must match either axes of B, $(axes(Q.factors))"))
103+
end
104+
_check_mul_axes(Q::FlexibleLeftQs, adjQ::AdjointQtype{<:Any,<:FlexibleLeftQs}) =
105+
invoke(_check_mul_axes, Tuple{Any,Any}, Q, adjQ)
92106
function check_mul_axes(A, B, C...)
93107
_check_mul_axes(A, B)
94108
check_mul_axes(B, C...)
95109
end
96110

97-
# we need to special case AbstractQ as it allows non-compatiple multiplication
98-
function check_mul_axes(A::Union{QRCompactWYQ,QRPackedQ}, B, C...)
99-
axes(A.factors, 1) == axes(B, 1) || axes(A.factors, 2) == axes(B, 1) ||
100-
throw(DimensionMismatch("First axis of B, $(axes(B,1)) must match either axes of A, $(axes(A))"))
101-
check_mul_axes(B, C...)
102-
end
103-
104111
function instantiate(M::Mul)
105112
@boundscheck check_mul_axes(M.A, M.B)
106113
M

test/test_muladd.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,12 @@ Random.seed!(0)
615615
Q = qr(randn(5,5)).Q
616616
b = randn(5)
617617
B = randn(5,5)
618+
@test Q*1.0 == ArrayLayouts.lmul!(Q, Matrix{Float64}(I, 5, 5))
618619
@test Q*b == ArrayLayouts.lmul!(Q, copy(b)) == mul(Q,b)
619620
@test Q*B == ArrayLayouts.lmul!(Q, copy(B)) == mul(Q,B)
620621
@test B*Q == ArrayLayouts.rmul!(copy(B), Q) == mul(B,Q)
622+
@test 1.0*Q ArrayLayouts.rmul!(Matrix{Float64}(I, 5, 5), Q)
623+
@test 1.0*Q' ArrayLayouts.rmul!(Matrix{Float64}(I, 5, 5), Q')
621624
@test Q*Q mul(Q,Q)
622625
@test Q'*b == ArrayLayouts.lmul!(Q', copy(b)) == mul(Q',b)
623626
@test Q'*B == ArrayLayouts.lmul!(Q', copy(B)) == mul(Q',B)
@@ -627,6 +630,21 @@ Random.seed!(0)
627630
@test Q'*Q mul(Q',Q)
628631
@test Q*UpperTriangular(B) mul(Q, UpperTriangular(B))
629632
@test UpperTriangular(B)*Q mul(UpperTriangular(B), Q)
633+
634+
Q = qr(randn(7,5)).Q
635+
b = randn(5)
636+
B = randn(5,5)
637+
@test Q*1.0 == ArrayLayouts.lmul!(Q, Matrix{Float64}(I, 7, 7))
638+
@test Q*b == mul(Q,b)
639+
@test Q*B == mul(Q,B)
640+
@test 1.0*Q ArrayLayouts.rmul!(Matrix{Float64}(I, 7, 7), Q)
641+
@test Q*Q mul(Q,Q)
642+
@test B*Q' == mul(B,Q')
643+
@test Q*Q' mul(Q,Q')
644+
@test Q'*Q' mul(Q',Q')
645+
@test Q'*Q mul(Q',Q)
646+
VERSION >= v"1.8-" && @test Q*UpperTriangular(B) mul(Q, UpperTriangular(B))
647+
@test UpperTriangular(B)*Q' mul(UpperTriangular(B), Q')
630648
end
631649

632650
@testset "Mul" begin

0 commit comments

Comments
 (0)