@@ -34,7 +34,7 @@ function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
34
34
end
35
35
36
36
flag = @allowscalar dh. info[1 ]
37
- chkargsok (BlasInt ( flag) )
37
+ chkargsok (flag |> BlasInt )
38
38
A, flag
39
39
end
40
40
@@ -52,7 +52,7 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe
52
52
cusolverDnXpotrs (dh, params, uplo, n, nrhs, T, A, lda, T, B, ldb, dh. info)
53
53
54
54
flag = @allowscalar dh. info[1 ]
55
- chkargsok (BlasInt ( flag) )
55
+ chkargsok (flag |> BlasInt )
56
56
B
57
57
end
58
58
@@ -77,7 +77,7 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{Int64}) where {T <: BlasF
77
77
end
78
78
79
79
flag = @allowscalar dh. info[1 ]
80
- chkargsok (BlasInt ( flag) )
80
+ chkargsok (flag |> BlasInt )
81
81
A, ipiv, flag
82
82
end
83
83
@@ -100,7 +100,7 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S
100
100
cusolverDnXgetrs (dh, params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, dh. info)
101
101
102
102
flag = @allowscalar dh. info[1 ]
103
- chkargsok (BlasInt ( flag) )
103
+ chkargsok (flag |> BlasInt )
104
104
B
105
105
end
106
106
@@ -125,7 +125,7 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
125
125
end
126
126
127
127
flag = @allowscalar dh. info[1 ]
128
- chkargsok (BlasInt ( flag) )
128
+ chkargsok (flag |> BlasInt )
129
129
A, tau
130
130
end
131
131
@@ -159,7 +159,7 @@ function sytrs!(uplo::Char, A::StridedCuMatrix{T}, p::CuVector{Int64}, B::Stride
159
159
end
160
160
161
161
flag = @allowscalar dh. info[1 ]
162
- chkargsok (BlasInt ( flag) )
162
+ chkargsok (flag |> BlasInt )
163
163
B
164
164
end
165
165
@@ -184,7 +184,7 @@ function trtri!(uplo::Char, diag::Char, A::StridedCuMatrix{T}) where {T <: BlasF
184
184
end
185
185
186
186
flag = @allowscalar dh. info[1 ]
187
- chkargsok (BlasInt ( flag) )
187
+ chkargsok (flag |> BlasInt )
188
188
A
189
189
end
190
190
@@ -262,7 +262,7 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
262
262
end
263
263
264
264
flag = @allowscalar dh. info[1 ]
265
- chklapackerror (BlasInt ( flag) )
265
+ chklapackerror (flag |> BlasInt )
266
266
U, Σ, Vt
267
267
end
268
268
@@ -314,7 +314,7 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
314
314
end
315
315
316
316
flag = @allowscalar dh. info[1 ]
317
- chklapackerror (BlasInt ( flag) )
317
+ chklapackerror (flag |> BlasInt )
318
318
if jobz == ' N'
319
319
return Σ, h_err_sigma[]
320
320
elseif jobz == ' V'
@@ -343,7 +343,7 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
343
343
elseif jobv == ' N'
344
344
CU_NULL
345
345
else
346
- throw (ArgumentError (" jobv is incorrect. The values accepted are S' and 'N'." ))
346
+ throw (ArgumentError (" jobv is incorrect. The values accepted are ' S' and 'N'." ))
347
347
end
348
348
lda = max (1 , stride (A, 2 ))
349
349
ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
@@ -367,7 +367,7 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
367
367
end
368
368
369
369
flag = @allowscalar dh. info[1 ]
370
- chklapackerror (BlasInt ( flag) )
370
+ chklapackerror (flag |> BlasInt )
371
371
U, Σ, V
372
372
end
373
373
@@ -395,7 +395,7 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
395
395
end
396
396
397
397
flag = @allowscalar dh. info[1 ]
398
- chkargsok (BlasInt ( flag) )
398
+ chkargsok (flag |> BlasInt )
399
399
400
400
if jobz == ' N'
401
401
return W
@@ -436,7 +436,7 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
436
436
end
437
437
438
438
flag = @allowscalar dh. info[1 ]
439
- chkargsok (BlasInt ( flag) )
439
+ chkargsok (flag |> BlasInt )
440
440
441
441
if jobz == ' N'
442
442
return W, h_meig[]
@@ -445,6 +445,51 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
445
445
end
446
446
end
447
447
448
+ # Xgeev
449
"""
    Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}

Call cuSOLVER's generic `cusolverDnXgeev` routine on the square matrix `A` and
return the tuple `(W, VL, VR)`.

# Arguments
- `jobvl`: `'V'` to allocate an `n × n` output matrix `VL`, `'N'` to pass `CU_NULL`.
- `jobvr`: `'V'` to allocate an `n × n` output matrix `VR`, `'N'` to pass `CU_NULL`.
- `A`: square device matrix handed to cuSOLVER.
  NOTE(review): the `!` suffix suggests `A` is overwritten by the routine — confirm
  against the cuSOLVER documentation.

# Returns
- `W`: device vector of length `n`; its element type is `Complex{T}` when `T` is
  real and `T` itself otherwise.
- `VL`, `VR`: `CuMatrix{T}` of size `n × n` when the matching job flag is `'V'`,
  otherwise the `CU_NULL` sentinel (not an empty matrix).

# Throws
- `ArgumentError` if `jobvl` or `jobvr` is neither `'V'` nor `'N'`.
- Whatever `checksquare` raises when `A` is not square, and whatever `chkargsok`
  raises for the status value read back from the solver handle.
"""
function Xgeev!(jobvl::Char, jobvr::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
    n = checksquare(A)
    VL = if jobvl == 'V'
        CuMatrix{T}(undef, n, n)
    elseif jobvl == 'N'
        CU_NULL
    else
        throw(ArgumentError("jobvl is incorrect. The values accepted are 'V' and 'N'."))
    end
    # The eigenvalue vector is always stored with a complex element type.
    C = T <: Real ? Complex{T} : T
    W = CuVector{C}(undef, n)
    VR = if jobvr == 'V'
        CuMatrix{T}(undef, n, n)
    elseif jobvr == 'N'
        CU_NULL
    else
        throw(ArgumentError("jobvr is incorrect. The values accepted are 'V' and 'N'."))
    end
    lda = max(1, stride(A, 2))
    # `===` tests for the sentinel itself; a leading dimension of 1 is passed when
    # the corresponding output matrix is absent.
    ldvl = VL === CU_NULL ? 1 : max(1, stride(VL, 2))
    ldvr = VR === CU_NULL ? 1 : max(1, stride(VR, 2))
    params = CuSolverParameters()
    dh = dense_handle()

    # Query the device and host workspace sizes (in that order) for this problem.
    function bufferSize()
        out_cpu = Ref{Csize_t}(0)
        out_gpu = Ref{Csize_t}(0)
        cusolverDnXgeev_bufferSize(dh, params, jobvl, jobvr, n, T, A,
                                   lda, C, W, T, VL, ldvl, T, VR, ldvr,
                                   T, out_gpu, out_cpu)
        return out_gpu[], out_cpu[]
    end
    with_workspaces(dh.workspace_gpu, dh.workspace_cpu, bufferSize()...) do buffer_gpu, buffer_cpu
        cusolverDnXgeev(dh, params, jobvl, jobvr, n, T, A, lda, C,
                        W, T, VL, ldvl, T, VR, ldvr, T, buffer_gpu,
                        sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), dh.info)
    end

    # Read the status word back from the device and validate it.
    flag = @allowscalar dh.info[1]
    chkargsok(flag |> BlasInt)

    return W, VL, VR
end
492
+
448
493
# LAPACK
449
494
for elty in (:Float32 , :Float64 , :ComplexF32 , :ComplexF64 )
450
495
@eval begin
0 commit comments