@@ -198,7 +198,37 @@ else
198
198
m′, n′ = size (B, 1 ), size (B, 2 )
199
199
m == d || throw (DimensionMismatch (" right hand side has $m rows but D is $d by $d " ))
200
200
(m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
201
- @. B = α * dd* A + β * B
201
+ @. B = α * dd * A + β * B
202
+
203
+ B
204
+ end
205
+
206
+ function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
207
+ A:: AbstractGPUVecOrMat ,
208
+ D:: Diagonal{<:Any, <:AbstractGPUArray} )
209
+ dd = D. diag
210
+ d = length (dd)
211
+ m, n = size (A, 1 ), size (A, 2 )
212
+ m′, n′ = size (B, 1 ), size (B, 2 )
213
+ n == d || throw (DimensionMismatch (" left hand side has $n columns but D is $d by $d " ))
214
+ (m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
215
+ @. B' = dd * A'
216
+
217
+ B
218
+ end
219
+
220
+ function LinearAlgebra. mul! (B:: AbstractGPUVecOrMat ,
221
+ A:: AbstractGPUVecOrMat ,
222
+ D:: Diagonal{<:Any, <:AbstractGPUArray} ,
223
+ α:: Number ,
224
+ β:: Number )
225
+ dd = D. diag
226
+ d = length (dd)
227
+ m, n = size (A, 1 ), size (A, 2 )
228
+ m′, n′ = size (B, 1 ), size (B, 2 )
229
+ n == d || throw (DimensionMismatch (" left hand side has $n columns but D is $d by $d " ))
230
+ (m, n) == (m′, n′) || throw (DimensionMismatch (" expect output to be $m by $n , but got $m′ by $n′ " ))
231
+ @. B' = α * dd * A' + β * B'
202
232
203
233
B
204
234
end
0 commit comments