Skip to content

Commit f81cdf7

Browse files
authored
Wrap CUBLAS.spmv and spr (#1248)
1 parent 4e6eb22 commit f81cdf7

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

lib/cublas/wrappers.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,32 @@ for (fname, elty) in ((:cublasDgbmv_v2,:Float64),
415415
end
416416
end
417417

418+
### spmv
419+
for (fname, elty) in ((:cublasDspmv_v2,:Float64),
                      (:cublasSspmv_v2,:Float32))
    @eval begin
        # In-place symmetric packed matrix-vector product, forwarded to the cuBLAS
        # `spmv` routine (documented by cuBLAS as y = alpha*A*x + beta*y).
        #
        # `AP` holds one triangle of the matrix (selected by `uplo`) in packed storage.
        # The matrix order `n` is recovered from `length(AP)`, and both `x` and `y`
        # must have length `n`. Returns the updated `y`.
        #
        # Throws `DimensionMismatch` when the vector lengths disagree with `n`, or
        # when `AP` is too short to hold an order-`n` packed triangle.
        function spmv!(uplo::Char,
                       alpha::Number,
                       AP::StridedCuVector{$elty},
                       x::StridedCuVector{$elty},
                       beta::Number,
                       y::StridedCuVector{$elty})
            # Invert length(AP) == n*(n+1)/2 to recover the matrix order.
            n = round(Int, (sqrt(8*length(AP))-1)/2)
            if n != length(x) || n != length(y)
                throw(DimensionMismatch("packed matrix of length $(length(AP)) implies order $n, but x has length $(length(x)) and y has length $(length(y))"))
            end
            # The rounding above can round *up* for non-triangular lengths (e.g. 5 -> 3);
            # refuse those so the library never reads past the end of `AP`.
            if length(AP) < (n*(n+1))>>1
                throw(DimensionMismatch("packed matrix has length $(length(AP)), but order $n requires at least $((n*(n+1))>>1) elements"))
            end
            incx = stride(x,1)
            incy = stride(y,1)
            $fname(handle(), uplo, n, alpha, AP, x, incx, beta, y, incy)
            return y
        end

        # Out-of-place variant: allocates the result with `similar(x)` and uses
        # beta = 0, so the uninitialised contents of the new vector are not meant
        # to contribute to the result.
        function spmv(uplo::Char, alpha::Number, AP::StridedCuVector{$elty}, x::StridedCuVector{$elty})
            return spmv!(uplo, alpha, AP, x, zero($elty), similar(x))
        end

        # Out-of-place variant with alpha = 1.
        function spmv(uplo::Char, AP::StridedCuVector{$elty}, x::StridedCuVector{$elty})
            return spmv(uplo, one($elty), AP, x)
        end
    end
end
443+
418444
### symv
419445
for (fname, elty) in ((:cublasDsymv_v2,:Float64),
420446
(:cublasSsymv_v2,:Float32),
@@ -695,6 +721,23 @@ for (fname, elty) in ((:cublasDger_v2,:Float64),
695721
end
696722
end
697723

724+
### spr
725+
for (fname, elty) in ((:cublasDspr_v2,:Float64),
                      (:cublasSspr_v2,:Float32))
    @eval begin
        # In-place symmetric packed rank-1 update, forwarded to the cuBLAS `spr`
        # routine (documented by cuBLAS as A = alpha*x*x' + A).
        #
        # `AP` holds one triangle of the matrix (selected by `uplo`) in packed storage
        # and is overwritten with the result. The matrix order `n` is recovered from
        # `length(AP)`, and `x` must have length `n`. Returns the updated `AP`.
        #
        # Throws `DimensionMismatch` when `length(x)` disagrees with `n`, or when `AP`
        # is too short to hold an order-`n` packed triangle.
        function spr!(uplo::Char,
                      alpha::Number,
                      x::StridedCuVector{$elty},
                      AP::StridedCuVector{$elty})
            # Invert length(AP) == n*(n+1)/2 to recover the matrix order.
            n = round(Int, (sqrt(8*length(AP))-1)/2)
            length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
            # The rounding above can round *up* for non-triangular lengths (e.g. 5 -> 3);
            # refuse those so the library never writes past the end of `AP`.
            if length(AP) < (n*(n+1))>>1
                throw(DimensionMismatch("packed matrix has length $(length(AP)), but order $n requires at least $((n*(n+1))>>1) elements"))
            end
            incx = stride(x,1)
            $fname(handle(), uplo, n, alpha, x, incx, AP)
            return AP
        end
    end
end
740+
698741
### syr
699742
# TODO: check calls in julia b/c blas may not define syr for Z and C
700743
for (fname, elty) in ((:cublasDsyr_v2,:Float64),

test/cublas.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,69 @@ end
298298
dhA = CuArray(hA)
299299
x = rand(elty,m)
300300
dx = CuArray(x)
301+
302+
# Pack one triangle of the square matrix `A` into a vector, column by column.
#
# `uplo == :L` packs the lower triangle (rows j:end of each column j); any other
# value packs the upper triangle (rows 1:j of each column j). The order and the
# element type are taken from `A` itself rather than from enclosing-scope
# variables, so the helper works for any square matrix it is handed.
#
# Throws `DimensionMismatch` if `A` is not square.
function pack(A, uplo)
    order = size(A, 1)
    size(A, 2) == order ||
        throw(DimensionMismatch("pack expects a square matrix, got size $(size(A))"))
    AP = Vector{eltype(A)}(undef, (order*(order+1))>>1)
    k = 1
    for j in 1:order
        for i in (uplo == :L ? (j:order) : (1:j))
            AP[k] = A[i,j]
            k += 1
        end
    end
    return AP
end
313+
314+
# The packed-storage wrappers (spmv, spr) are only generated for Float32 and
# Float64. `elty` is a type, so it has to be compared against types: the
# previous comparison against the strings "Float32"/"Float64" was never true
# and silently skipped every test below.
if elty in (Float32, Float64)
    # pack matrices
    sAPU = pack(sA, :U)
    dsAPU = CuVector(sAPU)
    sAPL = pack(sA, :L)
    dsAPL = CuVector(sAPL)

    @testset "spmv!" begin
        # generate vectors
        y = rand(elty,m)
        # copy to device
        dy = CuArray(y)
        # execute on host
        BLAS.spmv!('U',alpha,sAPU,x,beta,y)
        # execute on device
        CUBLAS.spmv!('U',alpha,dsAPU,dx,beta,dy)
        # compare results
        hy = Array(dy)
        @test y ≈ hy
        # execute on host (lower-packed data, so select the lower triangle)
        BLAS.spmv!('L',alpha,sAPL,x,beta,y)
        # execute on device
        CUBLAS.spmv!('L',alpha,dsAPL,dx,beta,dy)
        # compare results
        hy = Array(dy)
        @test y ≈ hy
    end

    @testset "spr!" begin
        # The host reference and the comparison are gated on the Julia version;
        # the device call runs unconditionally.
        # execute on host
        VERSION>=v"1.8.0-DEV.1049" && BLAS.spr!('U',alpha,x,sAPU)
        # execute on device
        CUBLAS.spr!('U',alpha,dx,dsAPU)
        # compare results
        if VERSION>=v"1.8.0-DEV.1049"
            hsAPU = Array(dsAPU)
            @test sAPU ≈ hsAPU
        end
        # execute on host (lower-packed data, so select the lower triangle)
        VERSION>=v"1.8.0-DEV.1049" && BLAS.spr!('L',alpha,x,sAPL)
        # execute on device
        CUBLAS.spr!('L',alpha,dx,dsAPL)
        # compare results
        if VERSION>=v"1.8.0-DEV.1049"
            hsAPL = Array(dsAPL)
            @test sAPL ≈ hsAPL
        end
    end
end
301364
@testset "symv!" begin
302365
# generate vectors
303366
y = rand(elty,m)

0 commit comments

Comments
 (0)