Skip to content

Commit 98e320c

Browse files
authored
Removed allocations for transpose/adjoint - diagonal multiplications (#2538)
1 parent 7d23bbf commit 98e320c

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

lib/cublas/linalg.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,19 +347,27 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::CuMatr
347347
end
348348

349349
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Transpose{T,<:CuMatrix}, B::Diagonal{T,<:CuVector}) where {T<:CublasFloat}
350-
return dgmm!('R', CuMatrix(A), B.diag, C)
350+
C .= A
351+
C .*= transpose(B.diag)
352+
return C
351353
end
352354

353355
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Transpose{T,<:CuMatrix}) where {T<:CublasFloat}
354-
return dgmm!('L', CuMatrix(B), A.diag, C)
356+
C .= B
357+
C .*= A.diag
358+
return C
355359
end
356360

357361
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Adjoint{T,<:CuMatrix}, B::Diagonal{T,<:CuVector}) where {T<:CublasFloat}
358-
return dgmm!('R', CuMatrix(A), B.diag, C)
362+
C .= A
363+
C .*= transpose(B.diag)
364+
return C
359365
end
360366

361367
function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoint{T,<:CuMatrix}) where {T<:CublasFloat}
362-
return dgmm!('L', CuMatrix(B), A.diag, C)
368+
C .= B
369+
C .*= A.diag
370+
return C
363371
end
364372

365373
# symmetric mul!

0 commit comments

Comments
 (0)