@@ -129,41 +129,55 @@ LinearAlgebra.adjoint(A::CompositeMap{T}) where {T} =
129
129
# comparison of CompositeMap objects
130
130
Base.:(== )(A:: CompositeMap , B:: CompositeMap ) = (eltype (A) == eltype (B) && A. maps == B. maps)
131
131
132
- # multiplication with vectors
132
+ # multiplication with vectors/matrices
133
133
_unsafe_mul! (y:: AbstractVecOrMat , A:: CompositeMap , x:: AbstractVector ) =
134
134
_compositemul! (y, A, x)
135
+ _unsafe_mul! (y:: AbstractMatrix , A:: CompositeMap , x:: AbstractMatrix ) =
136
+ _compositemul! (y, A, x)
137
+
135
138
function _compositemul! (y:: AbstractVecOrMat ,
136
139
A:: CompositeMap{<:Any,<:Tuple{LinearMap}} ,
137
- x:: AbstractVector ,
140
+ x:: AbstractVecOrMat ,
138
141
source = nothing ,
139
142
dest = nothing )
140
143
return _unsafe_mul! (y, A. maps[1 ], x)
141
144
end
142
145
function _compositemul! (y:: AbstractVecOrMat ,
143
- A:: CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}} , x:: AbstractVector ,
144
- source = similar (y, size (A. maps[1 ], 1 )),
146
+ A:: CompositeMap{<:Any,<:Tuple{LinearMap,LinearMap}} ,
147
+ x:: AbstractVecOrMat ,
148
+ source = similar (y, (size (A. maps[1 ],1 ), size (x)[2 : end ]. .. )),
145
149
dest = nothing )
146
150
_unsafe_mul! (source, A. maps[1 ], x)
147
151
_unsafe_mul! (y, A. maps[2 ], source)
148
152
return y
149
153
end
154
+
155
+ function _resize (dest:: AbstractVector , sz:: Tuple{<:Integer} )
156
+ try
157
+ resize! (dest, sz[1 ])
158
+ catch err
159
+ if err == ErrorException (" cannot resize array with shared data" )
160
+ dest = similar (dest, sz)
161
+ else
162
+ rethrow (err)
163
+ end
164
+ end
165
+ dest
166
+ end
167
+ function _resize (dest:: AbstractMatrix , sz:: Tuple{<:Integer,<:Integer} )
168
+ size (dest) == sz && return dest
169
+ similar (dest, sz)
170
+ end
171
+
150
172
function _compositemul! (y:: AbstractVecOrMat ,
151
173
A:: CompositeMap ,
152
- x:: AbstractVector ,
153
- source = similar (y, size (A. maps[1 ], 1 )),
154
- dest = similar (y, size (A. maps[2 ], 1 )))
174
+ x:: AbstractVecOrMat ,
175
+ source = similar (y, ( size (A. maps[1 ],1 ), size (x)[ 2 : end ] . .. )),
176
+ dest = similar (y, ( size (A. maps[2 ],1 ), size (x)[ 2 : end ] . .. )))
155
177
N = length (A. maps)
156
178
_unsafe_mul! (source, A. maps[1 ], x)
157
179
for n in 2 : N- 1
158
- try
159
- resize! (dest, size (A. maps[n], 1 ))
160
- catch err
161
- if err == ErrorException (" cannot resize array with shared data" )
162
- dest = similar (y, size (A. maps[n], 1 ))
163
- else
164
- rethrow (err)
165
- end
166
- end
180
+ dest = _resize (dest, (size (A. maps[n],1 ), size (x)[2 : end ]. .. ))
167
181
_unsafe_mul! (dest, A. maps[n], source)
168
182
dest, source = source, dest # alternate dest and source
169
183
end
0 commit comments