Skip to content

Commit b53903d

Browse files
Avoid fallback in mul!(y, A::CompositeMap, x::AbstractMatrix) (#143)
1 parent 752a157 commit b53903d

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

src/composition.jl

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -129,41 +129,55 @@ LinearAlgebra.adjoint(A::CompositeMap{T}) where {T} =
129129
# comparison of CompositeMap objects
130130
Base.:(==)(A::CompositeMap, B::CompositeMap) = (eltype(A) == eltype(B) && A.maps == B.maps)
131131

132-
# multiplication with vectors
132+
# multiplication with vectors/matrices
133133
_unsafe_mul!(y::AbstractVecOrMat, A::CompositeMap, x::AbstractVector) =
134134
_compositemul!(y, A, x)
135+
_unsafe_mul!(y::AbstractMatrix, A::CompositeMap, x::AbstractMatrix) =
136+
_compositemul!(y, A, x)
137+
135138
function _compositemul!(y::AbstractVecOrMat,
136139
A::CompositeMap{<:Any,<:Tuple{LinearMap}},
137-
x::AbstractVector,
140+
x::AbstractVecOrMat,
138141
source = nothing,
139142
dest = nothing)
140143
return _unsafe_mul!(y, A.maps[1], x)
141144
end
142145
function _compositemul!(y::AbstractVecOrMat,
143-
A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}}, x::AbstractVector,
144-
source = similar(y, size(A.maps[1], 1)),
146+
A::CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}},
147+
x::AbstractVecOrMat,
148+
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
145149
dest = nothing)
146150
_unsafe_mul!(source, A.maps[1], x)
147151
_unsafe_mul!(y, A.maps[2], source)
148152
return y
149153
end
154+
155+
function _resize(dest::AbstractVector, sz::Tuple{<:Integer})
156+
try
157+
resize!(dest, sz[1])
158+
catch err
159+
if err == ErrorException("cannot resize array with shared data")
160+
dest = similar(dest, sz)
161+
else
162+
rethrow(err)
163+
end
164+
end
165+
dest
166+
end
167+
function _resize(dest::AbstractMatrix, sz::Tuple{<:Integer,<:Integer})
168+
size(dest) == sz && return dest
169+
similar(dest, sz)
170+
end
171+
150172
function _compositemul!(y::AbstractVecOrMat,
151173
A::CompositeMap,
152-
x::AbstractVector,
153-
source = similar(y, size(A.maps[1], 1)),
154-
dest = similar(y, size(A.maps[2], 1)))
174+
x::AbstractVecOrMat,
175+
source = similar(y, (size(A.maps[1],1), size(x)[2:end]...)),
176+
dest = similar(y, (size(A.maps[2],1), size(x)[2:end]...)))
155177
N = length(A.maps)
156178
_unsafe_mul!(source, A.maps[1], x)
157179
for n in 2:N-1
158-
try
159-
resize!(dest, size(A.maps[n], 1))
160-
catch err
161-
if err == ErrorException("cannot resize array with shared data")
162-
dest = similar(y, size(A.maps[n], 1))
163-
else
164-
rethrow(err)
165-
end
166-
end
180+
dest = _resize(dest, (size(A.maps[n],1), size(x)[2:end]...))
167181
_unsafe_mul!(dest, A.maps[n], source)
168182
dest, source = source, dest # alternate dest and source
169183
end

0 commit comments

Comments
 (0)