@@ -385,6 +385,40 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoin
385
385
return C
386
386
end
387
387
388
+ # diagm
389
+
390
+ LinearAlgebra. diagm (kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_diagm (nothing , kv... )
391
+ LinearAlgebra. diagm (m:: Integer , n:: Integer , kv:: Pair{<:Integer,<:CuVector} ...) = _cuda_diagm ((Int (m),Int (n)), kv... )
392
+ LinearAlgebra. diagm (v:: CuVector ) = LinearAlgebra. diagm (0 => v)
393
+ LinearAlgebra. diagm (m:: Integer , n:: Integer , v:: CuVector ) = LinearAlgebra. diagm (m, n, 0 => v)
394
+
395
+ function _cuda_diagm (size, kv:: Pair{<:Integer,<:CuVector} ...)
396
+ A = LinearAlgebra. diagm_container (size, kv... )
397
+ for p in kv
398
+ inds = LinearAlgebra. diagind (A, p. first)
399
+ copyto! (view (A, inds), p. second)
400
+ end
401
+ return A
402
+ end
403
+
404
+ function LinearAlgebra. diagm_container (size, kv:: Pair{<:Integer,<:CuVector} ...)
405
+ T = promote_type (map (x -> eltype (x. second), kv)... )
406
+ U = promote_type (T, typeof (zero (T)))
407
+ return cu (zeros (U, LinearAlgebra. diagm_size (size, kv... )... ))
408
+ end
409
+
410
+ function LinearAlgebra. diagm_size (size:: Nothing , kv:: Pair{<:Integer,<:CuVector} ...)
411
+ mnmax = mapreduce (x -> length (x. second) + abs (Int (x. first)), max, kv; init= 0 )
412
+ return mnmax, mnmax
413
+ end
414
+ function LinearAlgebra. diagm_size (size:: Tuple{Int,Int} , kv:: Pair{<:Integer,<:CuVector} ...)
415
+ mmax = mapreduce (x -> length (x. second) - min (0 ,Int (x. first)), max, kv; init= 0 )
416
+ nmax = mapreduce (x -> length (x. second) + max (0 ,Int (x. first)), max, kv; init= 0 )
417
+ m, n = size
418
+ (m ≥ mmax && n ≥ nmax) || throw (DimensionMismatch (lazy " invalid size=$size" ))
419
+ return m, n
420
+ end
421
+
388
422
# symmetric mul!
389
423
390
424
op_wrappers = ((identity, T -> ' N' , identity),
0 commit comments