diff --git a/src/adjtrans.jl b/src/adjtrans.jl index d81aa3ae..a2fc3393 100644 --- a/src/adjtrans.jl +++ b/src/adjtrans.jl @@ -327,6 +327,11 @@ wrapperop(_) = identity wrapperop(::Adjoint) = adjoint wrapperop(::Transpose) = transpose +# equivalent to wrapperop, but treats real transposes and adjoints identically +# this helps reduce compilation latencies, and also matches the behavior of `wrapper_char` +_wrapperop(x) = wrapperop(x) +_wrapperop(::Adjoint{<:Real}) = transpose + # the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat size(A::AdjOrTrans) = reverse(size(A.parent)) axes(A::AdjOrTrans) = reverse(axes(A.parent)) @@ -341,8 +346,8 @@ IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear() @propagate_inbounds Base.isassigned(v::AdjOrTransAbsMat, i::Int, j::Int) = isassigned(v.parent, j, i) @propagate_inbounds getindex(v::AdjOrTransAbsVec{T}, i::Int) where {T} = wrapperop(v)(v.parent[i-1+first(axes(v.parent)[1])])::T @propagate_inbounds getindex(A::AdjOrTransAbsMat{T}, i::Int, j::Int) where {T} = wrapperop(A)(A.parent[j, i])::T -@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v) -@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A) +@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, _wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v) +@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, _wrapperop(A)(x), j, i); A) # AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate @propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is]) @propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:]) diff --git a/src/diagonal.jl b/src/diagonal.jl index 5cad3326..fa402ac6 100644 --- a/src/diagonal.jl +++ b/src/diagonal.jl @@ -333,14 +333,14 @@ function (*)(D::Diagonal, V::AbstractVector) end function mul(A::AdjOrTransAbsMat, D::Diagonal) - adj = wrapperop(A) + adj = _wrapperop(A) copy(adj(adj(D) * adj(A))) end function mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number}) @invoke mul(A::AbstractMatrix, D::AbstractMatrix) end function mul(D::Diagonal, A::AdjOrTransAbsMat) - adj = wrapperop(A) + adj = _wrapperop(A) copy(adj(adj(A) * adj(D))) end function mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}) @@ -358,7 +358,7 @@ end # A' = A' * D => A = D' * A # This uses the fact that D' is a Diagonal function rmul!(A::AdjOrTransAbsMat, D::Diagonal) - f = wrapperop(A) + f = _wrapperop(A) lmul!(f(D), f(A)) A end @@ -406,7 +406,7 @@ end # A' = D * A' => A = A * D' # This uses the fact that D' is a Diagonal function lmul!(D::Diagonal, A::AdjOrTransAbsMat) - f = wrapperop(A) + f = _wrapperop(A) rmul!(f(A), f(D)) A end diff --git a/src/matmul.jl b/src/matmul.jl index 202ef763..5f035172 100644 --- a/src/matmul.jl +++ b/src/matmul.jl @@ -134,14 +134,14 @@ matprod_dest(A, B, T) = similar(B, T, (size(A, 1), size(B, 2))) function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal}) TS = promote_type(eltype(A), eltype(B)) mul!(similar(B, TS, (size(A, 1), size(B, 2))), - wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))), - wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B)))) + _wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))), + _wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B)))) end function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex}) TS = promote_type(eltype(A), eltype(B)) mul!(similar(B, TS, (size(A, 1), size(B, 2))), - wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))), - wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B)))) + _wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))), + _wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B)))) end # Complex Matrix times real matrix: We use that it is generally faster to reinterpret the @@ -150,13 +150,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla TS = promote_type(eltype(A), eltype(B)) mul!(similar(B, TS, (size(A, 1), size(B, 2))), convert(AbstractArray{TS}, A), - wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B)))) + _wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B)))) end function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal}) TS = promote_type(eltype(A), eltype(B)) mul!(similar(B, TS, (size(A, 1), size(B, 2))), copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below - wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B)))) + _wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B)))) end # the following case doesn't seem to benefit from the translation A*B = (B' * A')' function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex}) @@ -1031,7 +1031,7 @@ end function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta) _rmul_or_fill!(C, beta) (iszero(alpha) || isempty(A) || isempty(B)) && return C - t = wrapperop(A) + t = _wrapperop(A) pB = parent(B) pA = parent(A) tmp = similar(C, axes(C, 2)) diff --git a/src/triangular.jl b/src/triangular.jl index ad66f492..75035437 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -1079,18 +1079,18 @@ _trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) = lmul!(A, copy!(C, B)) # redirect for UpperOrLowerTriangular _trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) = - generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B) + generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B) _trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractMatrix) = - generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B) + generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B) _trimul!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) = - generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B))) + generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B))) _trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) = - generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B) + generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B) # disambiguation with AbstractTriangular _trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractTriangular) = - generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B) + generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B) _trimul!(C::AbstractMatrix, A::AbstractTriangular, B::UpperOrLowerTriangular) = - generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B))) + generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B))) # methods for LinearAlgebra.jl's own triangular types, to avoid `istriu` checks lmul!(A::UpperOrLowerTriangular, B::AbstractVecOrMat) = @inline _trimul!(B, A, B) @@ -1142,9 +1142,9 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) = rdiv!(copy!(C, A), B) # redirect for UpperOrLowerTriangular to generic_*div! _ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) = - generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B) + generic_trimatdiv!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B) _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) = - generic_mattridiv!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B))) + generic_mattridiv!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B))) function ldiv!(A::AbstractTriangular, B::AbstractVecOrMat) if istriu(A)