Skip to content

Commit 1004fd9

Browse files
authored
CUBLAS: Add support for diagm (#2786)
1 parent 22a875b commit 1004fd9

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,40 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoin
385385
return C
386386
end
387387

388+
# diagm
389+
390+
LinearAlgebra.diagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm(nothing, kv...)
391+
LinearAlgebra.diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm((Int(m),Int(n)), kv...)
392+
LinearAlgebra.diagm(v::CuVector) = LinearAlgebra.diagm(0 => v)
393+
LinearAlgebra.diagm(m::Integer, n::Integer, v::CuVector) = LinearAlgebra.diagm(m, n, 0 => v)
394+
395+
function _cuda_diagm(size, kv::Pair{<:Integer,<:CuVector}...)
396+
A = LinearAlgebra.diagm_container(size, kv...)
397+
for p in kv
398+
inds = LinearAlgebra.diagind(A, p.first)
399+
copyto!(view(A, inds), p.second)
400+
end
401+
return A
402+
end
403+
404+
function LinearAlgebra.diagm_container(size, kv::Pair{<:Integer,<:CuVector}...)
405+
T = promote_type(map(x -> eltype(x.second), kv)...)
406+
U = promote_type(T, typeof(zero(T)))
407+
return cu(zeros(U, LinearAlgebra.diagm_size(size, kv...)...))
408+
end
409+
410+
function LinearAlgebra.diagm_size(size::Nothing, kv::Pair{<:Integer,<:CuVector}...)
411+
mnmax = mapreduce(x -> length(x.second) + abs(Int(x.first)), max, kv; init=0)
412+
return mnmax, mnmax
413+
end
414+
function LinearAlgebra.diagm_size(size::Tuple{Int,Int}, kv::Pair{<:Integer,<:CuVector}...)
415+
mmax = mapreduce(x -> length(x.second) - min(0,Int(x.first)), max, kv; init=0)
416+
nmax = mapreduce(x -> length(x.second) + max(0,Int(x.first)), max, kv; init=0)
417+
m, n = size
418+
(m mmax && n nmax) || throw(DimensionMismatch(lazy"invalid size=$size"))
419+
return m, n
420+
end
421+
388422
# symmetric mul!
389423

390424
op_wrappers = ((identity, T -> 'N', identity),

test/libraries/cublas/level3.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,23 @@ k = 13
467467
h_C = triu(h_C)
468468
@test C h_C
469469
end
470+
@testset "diagm" begin
471+
A = rand(elty, m)
472+
B = rand(elty, n)
473+
# move to device
474+
d_A = CuArray(A)
475+
d_B = CuArray(B)
476+
diagA = diagm(d_A)
477+
diagB = diagm(2 => d_B)
478+
# move back to host and compare
479+
diagind_A = diagind(diagA, 0)
480+
diagind_B = diagind(diagB, 2)
481+
h_A = Array(diagA[diagind_A])
482+
h_B = Array(diagB[diagind_B])
483+
484+
@test A h_A
485+
@test B h_B
486+
end
470487
if elty <: Complex
471488
@testset "herk!" begin
472489
alpha = rand(real(elty))

0 commit comments

Comments
 (0)