Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ wrapperop(_) = identity
wrapperop(::Adjoint) = adjoint
wrapperop(::Transpose) = transpose

# equivalent to wrapperop, but treats real transposes and adjoints identically
# this helps reduce compilation latencies, and also matches the behavior of `wrapper_char`
_wrapperop(x) = wrapperop(x)
_wrapperop(::Adjoint{<:Real}) = transpose

# the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat
size(A::AdjOrTrans) = reverse(size(A.parent))
axes(A::AdjOrTrans) = reverse(axes(A.parent))
Expand All @@ -341,8 +346,8 @@ IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear()
@propagate_inbounds Base.isassigned(v::AdjOrTransAbsMat, i::Int, j::Int) = isassigned(v.parent, j, i)
@propagate_inbounds getindex(v::AdjOrTransAbsVec{T}, i::Int) where {T} = wrapperop(v)(v.parent[i-1+first(axes(v.parent)[1])])::T
@propagate_inbounds getindex(A::AdjOrTransAbsMat{T}, i::Int, j::Int) where {T} = wrapperop(A)(A.parent[j, i])::T
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v)
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A)
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, _wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v)
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, _wrapperop(A)(x), j, i); A)
# AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is])
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:])
Expand Down
8 changes: 4 additions & 4 deletions src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,14 @@ function (*)(D::Diagonal, V::AbstractVector)
end

function mul(A::AdjOrTransAbsMat, D::Diagonal)
adj = wrapperop(A)
adj = _wrapperop(A)
copy(adj(adj(D) * adj(A)))
end
function mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
@invoke mul(A::AbstractMatrix, D::AbstractMatrix)
end
function mul(D::Diagonal, A::AdjOrTransAbsMat)
adj = wrapperop(A)
adj = _wrapperop(A)
copy(adj(adj(A) * adj(D)))
end
function mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
Expand All @@ -358,7 +358,7 @@ end
# A' = A' * D => A = D' * A
# This uses the fact that D' is a Diagonal
function rmul!(A::AdjOrTransAbsMat, D::Diagonal)
f = wrapperop(A)
f = _wrapperop(A)
lmul!(f(D), f(A))
A
end
Expand Down Expand Up @@ -406,7 +406,7 @@ end
# A' = D * A' => A = A * D'
# This uses the fact that D' is a Diagonal
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
f = wrapperop(A)
f = _wrapperop(A)
rmul!(f(A), f(D))
A
end
Expand Down
14 changes: 7 additions & 7 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ matprod_dest(A, B, T) = similar(B, T, (size(A, 1), size(B, 2)))
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
end

# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
Expand All @@ -150,13 +150,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
convert(AbstractArray{TS}, A),
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
TS = promote_type(eltype(A), eltype(B))
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
end
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
Expand Down Expand Up @@ -1031,7 +1031,7 @@ end
function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
_rmul_or_fill!(C, beta)
(iszero(alpha) || isempty(A) || isempty(B)) && return C
t = wrapperop(A)
t = _wrapperop(A)
pB = parent(B)
pA = parent(A)
tmp = similar(C, axes(C, 2))
Expand Down
16 changes: 8 additions & 8 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1079,18 +1079,18 @@ _trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
lmul!(A, copy!(C, B))
# redirect for UpperOrLowerTriangular
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractMatrix) =
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) =
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
# disambiguation with AbstractTriangular
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractTriangular) =
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::UpperOrLowerTriangular) =
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))

# methods for LinearAlgebra.jl's own triangular types, to avoid `istriu` checks
lmul!(A::UpperOrLowerTriangular, B::AbstractVecOrMat) = @inline _trimul!(B, A, B)
Expand Down Expand Up @@ -1142,9 +1142,9 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
rdiv!(copy!(C, A), B)
# redirect for UpperOrLowerTriangular to generic_*div!
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
generic_mattridiv!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
generic_mattridiv!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))

function ldiv!(A::AbstractTriangular, B::AbstractVecOrMat)
if istriu(A)
Expand Down