@@ -116,37 +116,44 @@ Base.:(*)(A₁::UniformScaling, A₂::LinearMap) = A₁.λ * A₂
116
116
LinearAlgebra. transpose (A:: CompositeMap{T} ) where {T} = CompositeMap {T} (map (transpose, reverse (A. maps)))
117
117
LinearAlgebra. adjoint (A:: CompositeMap{T} ) where {T} = CompositeMap {T} (map (adjoint, reverse (A. maps)))
118
118
119
- # comparison of LinearCombination objects
119
+ # comparison of CompositeMap objects
120
120
Base.:(== )(A:: CompositeMap , B:: CompositeMap ) = (eltype (A) == eltype (B) && A. maps == B. maps)
121
121
122
122
# multiplication with vectors
123
- function A_mul_B! (y:: AbstractVector , A:: CompositeMap , x:: AbstractVector )
123
+ function A_mul_B! (y:: AbstractVector , A:: CompositeMap{T,<:Tuple{LinearMap}} , x:: AbstractVector ) where {T}
124
+ return A_mul_B! (y, A. maps[1 ], x)
125
+ end
126
+ function A_mul_B! (y:: AbstractVector , A:: CompositeMap{T,<:Tuple{LinearMap,LinearMap}} , x:: AbstractVector ) where {T}
127
+ _compositemul! (y, A, x, similar (y, size (A. maps[1 ], 1 )))
128
+ end
129
+ function A_mul_B! (y:: AbstractVector , A:: CompositeMap{T,<:Tuple{Vararg{LinearMap}}} , x:: AbstractVector ) where {T}
130
+ _compositemul! (y, A, x, similar (y, size (A. maps[1 ], 1 )), similar (y, size (A. maps[2 ], 1 )))
131
+ end
132
+
133
+ function _compositemul! (y:: AbstractVector , A:: CompositeMap{T,<:Tuple{LinearMap,LinearMap}} , x:: AbstractVector , z:: AbstractVector ) where {T}
134
+ # no size checking, will be done by individual maps
135
+ A_mul_B! (z, A. maps[1 ], x)
136
+ A_mul_B! (y, A. maps[2 ], z)
137
+ return y
138
+ end
139
+ function _compositemul! (y:: AbstractVector , A:: CompositeMap{T,<:Tuple{Vararg{LinearMap}}} , x:: AbstractVector , source:: AbstractVector , dest:: AbstractVector ) where {T}
124
140
# no size checking, will be done by individual maps
125
141
N = length (A. maps)
126
- if N== 1
127
- A_mul_B! (y, A. maps[1 ], x)
128
- else
129
- dest = similar (y, size (A. maps[1 ], 1 ))
130
- A_mul_B! (dest, A. maps[1 ], x)
131
- source = dest
132
- if N> 2
133
- dest = similar (y, size (A. maps[2 ], 1 ))
134
- end
135
- for n in 2 : N- 1
136
- try
137
- resize! (dest, size (A. maps[n], 1 ))
138
- catch err
139
- if err == ErrorException (" cannot resize array with shared data" )
140
- dest = similar (y, size (A. maps[n], 1 ))
141
- else
142
- rethrow (err)
143
- end
142
+ A_mul_B! (source, A. maps[1 ], x)
143
+ for n in 2 : N- 1
144
+ try
145
+ resize! (dest, size (A. maps[n], 1 ))
146
+ catch err
147
+ if err == ErrorException (" cannot resize array with shared data" )
148
+ dest = similar (y, size (A. maps[n], 1 ))
149
+ else
150
+ rethrow (err)
144
151
end
145
- A_mul_B! (dest, A. maps[n], source)
146
- dest, source = source, dest # alternate dest and source
147
152
end
148
- A_mul_B! (y, A. maps[N], source)
153
+ A_mul_B! (dest, A. maps[n], source)
154
+ dest, source = source, dest # alternate dest and source
149
155
end
156
+ A_mul_B! (y, A. maps[N], source)
150
157
return y
151
158
end
152
159
0 commit comments