Skip to content

Commit 22da046

Browse files
authored
[CUSOLVER] Interface Xgeev! (#2513)
1 parent 56eb51b commit 22da046

File tree

2 files changed

+85
-13
lines changed

2 files changed

+85
-13
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
3434
end
3535

3636
flag = @allowscalar dh.info[1]
37-
chkargsok(BlasInt(flag))
37+
chkargsok(flag |> BlasInt)
3838
A, flag
3939
end
4040

@@ -52,7 +52,7 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe
5252
cusolverDnXpotrs(dh, params, uplo, n, nrhs, T, A, lda, T, B, ldb, dh.info)
5353

5454
flag = @allowscalar dh.info[1]
55-
chkargsok(BlasInt(flag))
55+
chkargsok(flag |> BlasInt)
5656
B
5757
end
5858

@@ -77,7 +77,7 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{Int64}) where {T <: BlasF
7777
end
7878

7979
flag = @allowscalar dh.info[1]
80-
chkargsok(BlasInt(flag))
80+
chkargsok(flag |> BlasInt)
8181
A, ipiv, flag
8282
end
8383

@@ -100,7 +100,7 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S
100100
cusolverDnXgetrs(dh, params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, dh.info)
101101

102102
flag = @allowscalar dh.info[1]
103-
chkargsok(BlasInt(flag))
103+
chkargsok(flag |> BlasInt)
104104
B
105105
end
106106

@@ -125,7 +125,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
125125
end
126126

127127
flag = @allowscalar dh.info[1]
128-
chkargsok(BlasInt(flag))
128+
chkargsok(flag |> BlasInt)
129129
A, tau
130130
end
131131

@@ -159,7 +159,7 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride
159159
end
160160

161161
flag = @allowscalar dh.info[1]
162-
chkargsok(BlasInt(flag))
162+
chkargsok(flag |> BlasInt)
163163
B
164164
end
165165

@@ -184,7 +184,7 @@ function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasF
184184
end
185185

186186
flag = @allowscalar dh.info[1]
187-
chkargsok(BlasInt(flag))
187+
chkargsok(flag |> BlasInt)
188188
A
189189
end
190190

@@ -262,7 +262,7 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
262262
end
263263

264264
flag = @allowscalar dh.info[1]
265-
chklapackerror(BlasInt(flag))
265+
chklapackerror(flag |> BlasInt)
266266
U, Σ, Vt
267267
end
268268

@@ -314,7 +314,7 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
314314
end
315315

316316
flag = @allowscalar dh.info[1]
317-
chklapackerror(BlasInt(flag))
317+
chklapackerror(flag |> BlasInt)
318318
if jobz == 'N'
319319
return Σ, h_err_sigma[]
320320
elseif jobz == 'V'
@@ -343,7 +343,7 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
343343
elseif jobv == 'N'
344344
CU_NULL
345345
else
346-
throw(ArgumentError("jobv is incorrect. The values accepted are S' and 'N'."))
346+
throw(ArgumentError("jobv is incorrect. The values accepted are 'S' and 'N'."))
347347
end
348348
lda = max(1, stride(A, 2))
349349
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
@@ -367,7 +367,7 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
367367
end
368368

369369
flag = @allowscalar dh.info[1]
370-
chklapackerror(BlasInt(flag))
370+
chklapackerror(flag |> BlasInt)
371371
U, Σ, V
372372
end
373373

@@ -395,7 +395,7 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
395395
end
396396

397397
flag = @allowscalar dh.info[1]
398-
chkargsok(BlasInt(flag))
398+
chkargsok(flag |> BlasInt)
399399

400400
if jobz == 'N'
401401
return W
@@ -436,7 +436,7 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
436436
end
437437

438438
flag = @allowscalar dh.info[1]
439-
chkargsok(BlasInt(flag))
439+
chkargsok(flag |> BlasInt)
440440

441441
if jobz == 'N'
442442
return W, h_meig[]
@@ -445,6 +445,51 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
445445
end
446446
end
447447

448+
# Xgeev
449+
function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
450+
n = checksquare(A)
451+
VL = if jobvl == 'V'
452+
CuMatrix{T}(undef, n, n)
453+
elseif jobvl == 'N'
454+
CU_NULL
455+
else
456+
throw(ArgumentError("jobvl is incorrect. The values accepted are 'V' and 'N'."))
457+
end
458+
C = T <: Real ? Complex{T} : T
459+
W = CuVector{C}(undef, n)
460+
VR = if jobvr == 'V'
461+
CuMatrix{T}(undef, n, n)
462+
elseif jobvr == 'N'
463+
CU_NULL
464+
else
465+
throw(ArgumentError("jobvr is incorrect. The values accepted are 'V' and 'N'."))
466+
end
467+
lda = max(1, stride(A, 2))
468+
ldvl = VL == CU_NULL ? 1 : max(1, stride(VL, 2))
469+
ldvr = VR == CU_NULL ? 1 : max(1, stride(VR, 2))
470+
params = CuSolverParameters()
471+
dh = dense_handle()
472+
473+
function bufferSize()
474+
out_cpu = Ref{Csize_t}(0)
475+
out_gpu = Ref{Csize_t}(0)
476+
cusolverDnXgeev_bufferSize(dh, params, jobvl, jobvr, n, T, A,
477+
lda, C, W, T, VL, ldvl, T, VR, ldvr,
478+
T, out_gpu, out_cpu)
479+
out_gpu[], out_cpu[]
480+
end
481+
with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
482+
cusolverDnXgeev(dh, params, jobvl, jobvr, n, T, A, lda, C,
483+
W, T, VL, ldvl, T, VR, ldvr, T, buffer_gpu,
484+
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info)
485+
end
486+
487+
flag = @allowscalar dh.info[1]
488+
chkargsok(flag |> BlasInt)
489+
490+
return W, VL, VR
491+
end
492+
448493
# LAPACK
449494
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
450495
@eval begin

test/libraries/cusolver/dense_generic.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,33 @@ n = 10
66
p = 5
77

88
@testset "cusolver -- generic API -- $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
9+
if CUSOLVER.version() >= v"11.7.1"
10+
@testset "geev!" begin
11+
A = rand(elty,n,n)
12+
d_A = CuMatrix(A)
13+
d_B = copy(d_A)
14+
W, VL, VR = CUSOLVER.Xgeev!('N', 'V', d_A)
15+
if elty <: Complex
16+
@test d_B * VR ≈ VR * Diagonal(W)
17+
else
18+
h_W = collect(W)
19+
i = 1
20+
while i <= n
21+
if h_W[i].im ≈ zero(elty)
22+
@test d_B * VR[:,i] ≈ h_W[i].re * VR[:,i]
23+
i = i + 1
24+
else
25+
V1 = VR[:,i] + im * VR[:,i+1]
26+
@test d_B * V1 ≈ h_W[i] * V1
27+
V2 = VR[:,i] - im * VR[:,i+1]
28+
@test d_B * V2 ≈ h_W[i+1] * V2
29+
i = i + 2
30+
end
31+
end
32+
end
33+
end
34+
end
35+
936
if CUSOLVER.version() >= v"11.6.0"
1037
@testset "larft!" begin
1138
@testset "direct = $direct" for direct in ('F', 'B')

0 commit comments

Comments
 (0)