Skip to content

Commit 18d18dc

Browse files
authored
Add unwrapping mechanism for triangular mul and solves (#50058)
This adds an unwrapping mechanism to triangular matrices, basically following the BLAS example in terms of characters encoding wrappers. It mirrors the `AdjOrTransOrHermOrSym` mechanism closely. Packages that want to overload by storage type can overload `generic_trimatmul!` (and potentially `generic_matrimul!`). Note the similarity to `generic_matvecmul!` and `generic_matmatmul!`. There is, unfortunately, some added code due to the fact that lazy conjugate wrappers have a different "wrapper depth" compared to the classic, e.g., `*Triangular{<:Any,<:Adjoint}`. I believe that with this PR we cover all wrappers of typically dense matrices with the unwrapping mechanism. ~~An analogous approach could be applied to `ldiv!`, if that's of interest and of benefit to the ecosystem.~~
2 parents d215d91 + e67ddaa commit 18d18dc

File tree

6 files changed

+550
-504
lines changed

6 files changed

+550
-504
lines changed

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ end
6464
Adjoint(A) = Adjoint{Base.promote_op(adjoint,eltype(A)),typeof(A)}(A)
6565
Transpose(A) = Transpose{Base.promote_op(transpose,eltype(A)),typeof(A)}(A)
6666

67+
# TODO: remove, is already replaced by wrapperop
6768
"""
6869
adj_or_trans(::AbstractArray) -> adjoint|transpose|identity
6970
adj_or_trans(::Type{<:AbstractArray}) -> adjoint|transpose|identity
70-
7171
Return [`adjoint`](@ref) from an `Adjoint` type or object and
7272
[`transpose`](@ref) from a `Transpose` type or object. Otherwise,
7373
return [`identity`](@ref). Note that `Adjoint` and `Transpose` have
@@ -94,9 +94,15 @@ inplace_adj_or_trans(::Type{<:AbstractArray}) = copyto!
9494
inplace_adj_or_trans(::Type{<:Adjoint}) = adjoint!
9595
inplace_adj_or_trans(::Type{<:Transpose}) = transpose!
9696

97+
# unwraps Adjoint, Transpose, Symmetric, Hermitian
9798
_unwrap(A::Adjoint) = parent(A)
9899
_unwrap(A::Transpose) = parent(A)
99100

101+
# unwraps Adjoint and Transpose only
102+
_unwrap_at(A) = A
103+
_unwrap_at(A::Adjoint) = parent(A)
104+
_unwrap_at(A::Transpose) = parent(A)
105+
100106
Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
101107
Base.unaliascopy(A::Union{Adjoint,Transpose}) = typeof(A)(Base.unaliascopy(A.parent))
102108

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ function ldiv!(c::AbstractVecOrMat, A::Bidiagonal, b::AbstractVecOrMat)
768768
end
769769
ldiv!(A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) = @inline ldiv!(b, A, b)
770770
ldiv!(c::AbstractVecOrMat, A::AdjOrTrans{<:Any,<:Bidiagonal}, b::AbstractVecOrMat) =
771-
(t = adj_or_trans(A); _rdiv!(t(c), t(b), t(A)); return c)
771+
(t = wrapperop(A); _rdiv!(t(c), t(b), t(A)); return c)
772772

773773
### Generic promotion methods and fallbacks
774774
\(A::Bidiagonal, B::AbstractVecOrMat) = ldiv!(_initarray(\, eltype(A), eltype(B), B), A, B)
@@ -846,7 +846,7 @@ end
846846
rdiv!(A::AbstractMatrix, B::Bidiagonal) = @inline _rdiv!(A, A, B)
847847
rdiv!(A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) = @inline _rdiv!(A, A, B)
848848
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AdjOrTrans{<:Any,<:Bidiagonal}) =
849-
(t = adj_or_trans(B); ldiv!(t(C), t(B), t(A)); return C)
849+
(t = wrapperop(B); ldiv!(t(C), t(B), t(A)); return C)
850850

851851
/(A::AbstractMatrix, B::Bidiagonal) = _rdiv!(_initarray(/, eltype(A), eltype(B), A), A, B)
852852

stdlib/LinearAlgebra/src/hessenberg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,11 @@ for T = (:Number, :UniformScaling, :Diagonal)
132132
end
133133

134134
function *(H::UpperHessenberg, U::UpperOrUnitUpperTriangular)
135-
HH = _mulmattri!(_initarray(*, eltype(H), eltype(U), H), H, U)
135+
HH = mul!(_initarray(*, eltype(H), eltype(U), H), H, U)
136136
UpperHessenberg(HH)
137137
end
138138
function *(U::UpperOrUnitUpperTriangular, H::UpperHessenberg)
139-
HH = _multrimat!(_initarray(*, eltype(U), eltype(H), H), U, H)
139+
HH = mul!(_initarray(*, eltype(U), eltype(H), H), U, H)
140140
UpperHessenberg(HH)
141141
end
142142

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@ AdjOrTransStridedMat{T} = Union{Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:
88
StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{<:Any, <:StridedMatrix{T}}, Transpose{<:Any, <:StridedMatrix{T}}}
99
StridedMaybeAdjOrTransVecOrMat{T} = Union{StridedVecOrMat{T}, AdjOrTrans{<:Any, <:StridedVecOrMat{T}}}
1010

11-
_parent(A) = A
12-
_parent(A::Adjoint) = parent(A)
13-
_parent(A::Transpose) = parent(A)
14-
1511
matprod(x, y) = x*y + x*y
1612

1713
# dot products
@@ -115,14 +111,14 @@ end
115111
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
116112
TS = promote_type(eltype(A), eltype(B))
117113
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
118-
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
119-
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
114+
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
115+
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
120116
end
121117
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
122118
TS = promote_type(eltype(A), eltype(B))
123119
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
124-
wrapperop(A)(convert(AbstractArray{TS}, _parent(A))),
125-
wrapperop(B)(convert(AbstractArray{TS}, _parent(B))))
120+
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
121+
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
126122
end
127123

128124
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
@@ -131,13 +127,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
131127
TS = promote_type(eltype(A), eltype(B))
132128
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
133129
convert(AbstractArray{TS}, A),
134-
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
130+
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
135131
end
136132
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
137133
TS = promote_type(eltype(A), eltype(B))
138134
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
139135
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
140-
wrapperop(B)(convert(AbstractArray{real(TS)}, _parent(B))))
136+
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
141137
end
142138
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
143139
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})

0 commit comments

Comments
 (0)