Skip to content

Commit 2486af3

Browse files
authored
Added multiplication of transpose / adjoint matrices by diagonal matrices (#2518)
1 parent 22da046 commit 2486af3

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,22 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::CuMatr
338338
return dgmm!('L', B, A.diag, C)
339339
end
340340

341+
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Transpose{T,<:CuMatrix}, B::Diagonal{T,<:CuVector}) where {T<:CublasFloat}
342+
return dgmm!('R', CuMatrix(A), B.diag, C)
343+
end
344+
345+
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Transpose{T,<:CuMatrix}) where {T<:CublasFloat}
346+
return dgmm!('L', CuMatrix(B), A.diag, C)
347+
end
348+
349+
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Adjoint{T,<:CuMatrix}, B::Diagonal{T,<:CuVector}) where {T<:CublasFloat}
350+
return dgmm!('R', CuMatrix(A), B.diag, C)
351+
end
352+
353+
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoint{T,<:CuMatrix}) where {T<:CublasFloat}
354+
return dgmm!('L', CuMatrix(B), A.diag, C)
355+
end
356+
341357
# symmetric mul!
342358

343359
op_wrappers = ((identity, T -> 'N', identity),

test/libraries/cublas.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,6 +2293,30 @@ end
22932293
d_Y = Diagonal(d_y)
22942294
mul!(d_AY, d_A, d_Y)
22952295
Array(d_AY) A * Diagonal(y)
2296+
2297+
YA = rand(elty,n,m)
2298+
d_YA = CuArray(YA)
2299+
d_Y = Diagonal(d_y)
2300+
mul!(d_YA, d_Y, transpose(d_A))
2301+
Array(d_YA) Diagonal(y) * transpose(A)
2302+
2303+
AX = rand(elty,n,m)
2304+
d_AX = CuArray(AX)
2305+
d_X = Diagonal(d_x)
2306+
mul!(d_AX, transpose(d_A), d_X)
2307+
Array(d_AX) transpose(A) * Diagonal(x)
2308+
2309+
YA = rand(elty,n,m)
2310+
d_YA = CuArray(YA)
2311+
d_Y = Diagonal(d_y)
2312+
mul!(d_YA, d_Y, d_A')
2313+
Array(d_YA) Diagonal(y) * A'
2314+
2315+
AX = rand(elty,n,m)
2316+
d_AX = CuArray(AX)
2317+
d_X = Diagonal(d_x)
2318+
mul!(d_AX, d_A', d_X)
2319+
Array(d_AX) A' * Diagonal(x)
22962320
end
22972321
end # extensions
22982322

0 commit comments

Comments
 (0)