Skip to content

Commit 78e6156

Browse files
authored
Treat real transposes like adjoint in internal dispatch (#1296)
Since real transposes are equivalent to adjoints, we may compile methods only for one of the two types when we are unwrapping a `Transpose` through the `wrapperop` mechanism. This will improve the time to the second execution in certain cases (as the same type will be re-used). This is primarily useful in internal method dispatches where the result of `wrapperop` will not be returned. For example, on master ```julia julia> using LinearAlgebra julia> U = UpperTriangular([1 2; 3 4]); julia> @time transpose(U) * parent(U); 0.140280 seconds (552.81 k allocations: 27.080 MiB, 99.90% compilation time) julia> @time adjoint(U) * parent(U); 0.124338 seconds (404.18 k allocations: 19.600 MiB, 99.92% compilation time) ``` whereas, on this PR, ```julia julia> @time transpose(U) * parent(U); 0.161032 seconds (553.25 k allocations: 27.090 MiB, 8.75% gc time, 99.91% compilation time) julia> @time adjoint(U) * parent(U); 0.053536 seconds (356.69 k allocations: 17.240 MiB, 99.81% compilation time) ``` The second execution is noticeably faster.
1 parent d21ad8c commit 78e6156

File tree

4 files changed

+26
-21
lines changed

4 files changed

+26
-21
lines changed

src/adjtrans.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,11 @@ wrapperop(_) = identity
327327
wrapperop(::Adjoint) = adjoint
328328
wrapperop(::Transpose) = transpose
329329

330+
# equivalent to wrapperop, but treats real transposes and adjoints identically
331+
# this helps reduce compilation latencies, and also matches the behavior of `wrapper_char`
332+
_wrapperop(x) = wrapperop(x)
333+
_wrapperop(::Adjoint{<:Real}) = transpose
334+
330335
# the following fallbacks can be removed if Adjoint/Transpose are restricted to AbstractVecOrMat
331336
size(A::AdjOrTrans) = reverse(size(A.parent))
332337
axes(A::AdjOrTrans) = reverse(axes(A.parent))
@@ -341,8 +346,8 @@ IndexStyle(::Type{<:AdjOrTransAbsVec}) = IndexLinear()
341346
@propagate_inbounds Base.isassigned(v::AdjOrTransAbsMat, i::Int, j::Int) = isassigned(v.parent, j, i)
342347
@propagate_inbounds getindex(v::AdjOrTransAbsVec{T}, i::Int) where {T} = wrapperop(v)(v.parent[i-1+first(axes(v.parent)[1])])::T
343348
@propagate_inbounds getindex(A::AdjOrTransAbsMat{T}, i::Int, j::Int) where {T} = wrapperop(A)(A.parent[j, i])::T
344-
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v)
345-
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, wrapperop(A)(x), j, i); A)
349+
@propagate_inbounds setindex!(v::AdjOrTransAbsVec, x, i::Int) = (setindex!(v.parent, _wrapperop(v)(x), i-1+first(axes(v.parent)[1])); v)
350+
@propagate_inbounds setindex!(A::AdjOrTransAbsMat, x, i::Int, j::Int) = (setindex!(A.parent, _wrapperop(A)(x), j, i); A)
346351
# AbstractArray interface, additional definitions to retain wrapper over vectors where appropriate
347352
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, is::AbstractArray{Int}) = wrapperop(v)(v.parent[is])
348353
@propagate_inbounds getindex(v::AdjOrTransAbsVec, ::Colon, ::Colon) = wrapperop(v)(v.parent[:])

src/diagonal.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,14 @@ function (*)(D::Diagonal, V::AbstractVector)
333333
end
334334

335335
function mul(A::AdjOrTransAbsMat, D::Diagonal)
336-
adj = wrapperop(A)
336+
adj = _wrapperop(A)
337337
copy(adj(adj(D) * adj(A)))
338338
end
339339
function mul(A::AdjOrTransAbsMat{<:Number, <:StridedMatrix}, D::Diagonal{<:Number})
340340
@invoke mul(A::AbstractMatrix, D::AbstractMatrix)
341341
end
342342
function mul(D::Diagonal, A::AdjOrTransAbsMat)
343-
adj = wrapperop(A)
343+
adj = _wrapperop(A)
344344
copy(adj(adj(A) * adj(D)))
345345
end
346346
function mul(D::Diagonal{<:Number}, A::AdjOrTransAbsMat{<:Number, <:StridedMatrix})
@@ -358,7 +358,7 @@ end
358358
# A' = A' * D => A = D' * A
359359
# This uses the fact that D' is a Diagonal
360360
function rmul!(A::AdjOrTransAbsMat, D::Diagonal)
361-
f = wrapperop(A)
361+
f = _wrapperop(A)
362362
lmul!(f(D), f(A))
363363
A
364364
end
@@ -406,7 +406,7 @@ end
406406
# A' = D * A' => A = A * D'
407407
# This uses the fact that D' is a Diagonal
408408
function lmul!(D::Diagonal, A::AdjOrTransAbsMat)
409-
f = wrapperop(A)
409+
f = _wrapperop(A)
410410
rmul!(f(A), f(D))
411411
A
412412
end

src/matmul.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,14 @@ matprod_dest(A, B, T) = similar(B, T, (size(A, 1), size(B, 2)))
134134
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasReal}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
135135
TS = promote_type(eltype(A), eltype(B))
136136
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
137-
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
138-
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
137+
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
138+
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
139139
end
140140
function (*)(A::StridedMaybeAdjOrTransMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasComplex})
141141
TS = promote_type(eltype(A), eltype(B))
142142
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
143-
wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
144-
wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
143+
_wrapperop(A)(convert(AbstractArray{TS}, _unwrap(A))),
144+
_wrapperop(B)(convert(AbstractArray{TS}, _unwrap(B))))
145145
end
146146

147147
# Complex Matrix times real matrix: We use that it is generally faster to reinterpret the
@@ -150,13 +150,13 @@ function (*)(A::StridedMatrix{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:Bla
150150
TS = promote_type(eltype(A), eltype(B))
151151
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
152152
convert(AbstractArray{TS}, A),
153-
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
153+
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
154154
end
155155
function (*)(A::AdjOrTransStridedMat{<:BlasComplex}, B::StridedMaybeAdjOrTransMat{<:BlasReal})
156156
TS = promote_type(eltype(A), eltype(B))
157157
mul!(similar(B, TS, (size(A, 1), size(B, 2))),
158158
copymutable_oftype(A, TS), # remove AdjOrTrans to use reinterpret trick below
159-
wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
159+
_wrapperop(B)(convert(AbstractArray{real(TS)}, _unwrap(B))))
160160
end
161161
# the following case doesn't seem to benefit from the translation A*B = (B' * A')'
162162
function (*)(A::StridedMatrix{<:BlasReal}, B::StridedMatrix{<:BlasComplex})
@@ -1031,7 +1031,7 @@ end
10311031
function _generic_matmatmul_adjtrans!(C, A, B, alpha, beta)
10321032
_rmul_or_fill!(C, beta)
10331033
(iszero(alpha) || isempty(A) || isempty(B)) && return C
1034-
t = wrapperop(A)
1034+
t = _wrapperop(A)
10351035
pB = parent(B)
10361036
pA = parent(A)
10371037
tmp = similar(C, axes(C, 2))

src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,18 +1105,18 @@ _trimul!(C::AbstractMatrix, A::AbstractTriangular, B::AbstractTriangular) =
11051105
lmul!(A, copy!(C, B))
11061106
# redirect for UpperOrLowerTriangular
11071107
_trimul!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVector) =
1108-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1108+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11091109
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractMatrix) =
1110-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1110+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11111111
_trimul!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
1112-
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1112+
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
11131113
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::UpperOrLowerTriangular) =
1114-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1114+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11151115
# disambiguation with AbstractTriangular
11161116
_trimul!(C::AbstractMatrix, A::UpperOrLowerTriangular, B::AbstractTriangular) =
1117-
generic_trimatmul!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1117+
generic_trimatmul!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11181118
_trimul!(C::AbstractMatrix, A::AbstractTriangular, B::UpperOrLowerTriangular) =
1119-
generic_mattrimul!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1119+
generic_mattrimul!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
11201120

11211121
# methods for LinearAlgebra.jl's own triangular types, to avoid `istriu` checks
11221122
lmul!(A::UpperOrLowerTriangular, B::AbstractVecOrMat) = @inline _trimul!(B, A, B)
@@ -1168,9 +1168,9 @@ _rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractTriangular) =
11681168
rdiv!(copy!(C, A), B)
11691169
# redirect for UpperOrLowerTriangular to generic_*div!
11701170
_ldiv!(C::AbstractVecOrMat, A::UpperOrLowerTriangular, B::AbstractVecOrMat) =
1171-
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), wrapperop(parent(A)), _unwrap_at(parent(A)), B)
1171+
generic_trimatdiv!(C, uplo_char(A), isunit_char(A), _wrapperop(parent(A)), _unwrap_at(parent(A)), B)
11721172
_rdiv!(C::AbstractMatrix, A::AbstractMatrix, B::UpperOrLowerTriangular) =
1173-
generic_mattridiv!(C, uplo_char(B), isunit_char(B), wrapperop(parent(B)), A, _unwrap_at(parent(B)))
1173+
generic_mattridiv!(C, uplo_char(B), isunit_char(B), _wrapperop(parent(B)), A, _unwrap_at(parent(B)))
11741174

11751175
function ldiv!(A::AbstractTriangular, B::AbstractVecOrMat)
11761176
if istriu(A)

0 commit comments

Comments
 (0)