Skip to content

Commit 625bc62

Browse files
committed
Treat real transposes like adjoint in internal dispatch
1 parent 07725da commit 625bc62

File tree

4 files changed

+24
-19
lines changed

4 files changed

+24
-19
lines changed

src/adjtrans.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ wrapperop(_) = identity
327327
wrapperop(::Adjoint) = adjoint
328328
wrapperop(::Transpose) = transpose
329329

330+
# equivalent to wrapperop, but treats real transposes and adjoints identically
331+
# this might help reduce compilation times
332+
_wrapperop(x) = wrapperop(x)
333+
_wrapperop(::Transpose{<:Real}) = adjoint
334+
330335
# the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat
331336
size(A::AdjOrTrans) = reverse(size(A.parent))
332337
axes(A::AdjOrTrans) = reverse(axes(A.parent))

src/diagonal.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,14 @@ function (*)(D::Diagonal, V::AbstractVector)
333333
end
334334

335335
function mul(A::AdjOrTransAbsMat, D::Diagonal)
336-
adj = wrapperop(A)
336+
adj = _wrapperop(A)
337337
copy(adj(adj(D) * adj(A)))
338338
end
339339
function mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
340340
@invoke mul(A::AbstractMatrix, D::AbstractMatrix)
341341
end
342342
function mul(D::Diagonal, A::AdjOrTransAbsMat)
343-
adj = wrapperop(A)
343+
adj = _wrapperop(A)
344344
copy(adj(adj(A) * adj(D)))
345345
end
346346
function mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
@@ -358,7 +358,7 @@ end
358358
# A' = A' * D => A = D' * A
359359
# This uses the fact that D' is a Diagonal
360360
function rmul!(A::AdjOrTransAbsMat, D::Diagonal)
361-
f = wrapperop(A)
361+
f = _wrapperop(A)
362362
lmul!(f(D), f(A))
363363
A
364364
end
@@ -406,7 +406,7 @@ end
406406
# A' = D * A' => A = A * D'
407407
# This uses the fact that D' is a Diagonal
408408
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
409-
f = wrapperop(A)
409+
f = _wrapperop(A)
410410
rmul!(f(A), f(D))
411411
A
412412
end

src/matmul.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ matprod_dest(A, B, T) = similar(B, T, (size(A, 1), size(B, 2)))
134134
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
135135
TS = promote_type(eltype(A), eltype(B))
136136
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
137-
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
138-
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
137+
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
138+
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
139139
end
140140
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
141141
TS = promote_type(eltype(A), eltype(B))
142142
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
143-
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
144-
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
143+
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
144+
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
145145
end
146146

147147
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
@@ -150,13 +150,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
150150
TS = promote_type(eltype(A), eltype(B))
151151
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
152152
convert(AbstractArray{TS}, A),
153-
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
153+
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
154154
end
155155
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
156156
TS = promote_type(eltype(A), eltype(B))
157157
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
158158
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
159-
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
159+
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
160160
end
161161
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
162162
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
@@ -1031,7 +1031,7 @@ end
10311031
function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
10321032
_rmul_or_fill!(C, beta)
10331033
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1034-
t = wrapperop(A)
1034+
t = _wrapperop(A)
10351035
pB = parent(B)
10361036
pA = parent(A)
10371037
tmp = similar(C, axes(C, 2))

src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,18 +1079,18 @@ _trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
10791079
lmul!(A, copy!(C, B))
10801080
# redirect for UpperOrLowerTriangular
10811081
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
1082-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1082+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
10831083
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractMatrix) =
1084-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1084+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
10851085
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
1086-
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1086+
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
10871087
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) =
1088-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1088+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
10891089
# disambiguation with AbstractTriangular
10901090
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractTriangular) =
1091-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1091+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
10921092
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::UpperOrLowerTriangular) =
1093-
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1093+
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
10941094

10951095
# methods for LinearAlgebra.jl's own triangular types, to avoid `istriu` checks
10961096
lmul!(A::UpperOrLowerTriangular, B::AbstractVecOrMat) = @inline _trimul!(B, A, B)
@@ -1142,9 +1142,9 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
11421142
rdiv!(copy!(C, A), B)
11431143
# redirect for UpperOrLowerTriangular to generic_*div!
11441144
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
1145-
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1145+
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11461146
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
1147-
generic_mattridiv!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1147+
generic_mattridiv!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
11481148

11491149
function ldiv!(A::AbstractTriangular, B::AbstractVecOrMat)
11501150
if istriu(A)

0 commit comments

Comments
 (0)