Skip to content

Commit 949a457

Browse files
committed
[oneMKL] Interface EVD and SVD decompositions
1 parent 9531fc0 commit 949a457

File tree

5 files changed

+311
-41
lines changed

5 files changed

+311
-41
lines changed

lib/mkl/utils.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,50 @@ function Base.convert(::Type{onemklLayout}, index::Char)
6464
end
6565
end
6666

67+
# Map a job character to the matching `ONEMKL_JOBSVD_*` constant.
# Accepted characters are 'N', 'A', 'O' and 'S'; any other character raises
# an `ArgumentError`.
function Base.convert(::Type{onemklJobsvd}, job::Char)
    job == 'N' && return ONEMKL_JOBSVD_N
    job == 'A' && return ONEMKL_JOBSVD_A
    job == 'O' && return ONEMKL_JOBSVD_O
    job == 'S' && return ONEMKL_JOBSVD_S
    throw(ArgumentError("Unknown job $job."))
end
80+
81+
# Map a job character to the matching `ONEMKL_JOB_*` constant.
# Accepted characters are 'N', 'V', 'U', 'A', 'S' and 'O'; any other character
# raises an `ArgumentError`.
function Base.convert(::Type{onemklJob}, job::Char)
    job == 'N' && return ONEMKL_JOB_N
    job == 'V' && return ONEMKL_JOB_V
    job == 'U' && return ONEMKL_JOB_U
    job == 'A' && return ONEMKL_JOB_A
    job == 'S' && return ONEMKL_JOB_S
    job == 'O' && return ONEMKL_JOB_O
    throw(ArgumentError("Unknown job $job."))
end
98+
99+
# Map an eigenvalue-range character to the matching `ONEMKL_RANGEV_*` constant.
# Accepted characters are 'A', 'V' and 'I'; any other character raises an
# `ArgumentError`.
function Base.convert(::Type{onemklRangev}, range::Char)
    range == 'A' && return ONEMKL_RANGEV_A
    range == 'V' && return ONEMKL_RANGEV_V
    range == 'I' && return ONEMKL_RANGEV_I
    throw(ArgumentError("Unknown eigenvalue solver range $range."))
end
110+
67111
# create a batch of pointers in device memory from a batch of device arrays
68112
@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T}
69113
ptrs = pointer.(batch)

lib/mkl/wrappers_lapack.jl

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ for (bname, fname, elty) in ((:onemklSorgqr_scratchpad_size, :onemklSorgqr, :Flo
268268
end
269269
end
270270

271-
#gebrd
271+
# gebrd
272272
for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size, :onemklSgebrd, :Float32, :Float32),
273273
(:onemklDgebrd_scratchpad_size, :onemklDgebrd, :Float64, :Float64),
274274
(:onemklCgebrd_scratchpad_size, :onemklCgebrd, :ComplexF32, :Float32),
@@ -280,7 +280,7 @@ for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size, :onemklSgebr
280280

281281
k = min(m, n)
282282
D = oneVector{$relty}(undef, k)
283-
E = oneVector{$elty}(undef, k)
283+
E = oneVector{$relty}(undef, k-1)
284284
tauq = oneVector{$elty}(undef, k)
285285
taup = oneVector{$elty}(undef, k)
286286

@@ -294,6 +294,114 @@ for (bname, fname, elty, relty) in ((:onemklSgebrd_scratchpad_size, :onemklSgebr
294294
end
295295
end
296296

297+
# gesvd
298+
for (bname, fname, elty, relty) in ((:onemklSgesvd_scratchpad_size, :onemklSgesvd, :Float32, :Float32),
                                    (:onemklDgesvd_scratchpad_size, :onemklDgesvd, :Float64, :Float64),
                                    (:onemklCgesvd_scratchpad_size, :onemklCgesvd, :ComplexF32, :Float32),
                                    (:onemklZgesvd_scratchpad_size, :onemklZgesvd, :ComplexF64, :Float64))
    @eval begin
        # gesvd!(jobu, jobvt, A) -> (U, S, Vt)
        #
        # Run the oneMKL `gesvd` routine on the m×n device matrix `A`.
        #
        # `jobu` / `jobvt` select the shape of the returned factors:
        #   'A' -> full square factor (m×m for U, n×n for Vt)
        #   'S' -> thin factor (m×min(m,n) for U, min(m,n)×n for Vt)
        #   'O' -> same allocation as 'S'; per the oneMKL documentation the
        #          factor itself is written into `A` in this mode, so the
        #          returned array is presumably left untouched (TODO confirm)
        #   'N' -> 0×0 placeholder, the factor is not requested
        # Any other character raises an error. `jobu` and `jobvt` may not both
        # be 'O' (an `ArgumentError` is thrown before anything is allocated).
        #
        # `S` is a real vector of length min(m, n).
        function gesvd!(jobu::Char,
                        jobvt::Char,
                        A::oneStridedMatrix{$elty})
            # Both factors cannot be overwritten onto A at the same time.
            if jobu === 'O' && jobvt === 'O'
                throw(ArgumentError("jobu and jobvt cannot both be 'O'"))
            end

            m, n = size(A)
            k = min(m, n)
            lda = max(1, stride(A, 2))

            U = if jobu === 'A'
                oneMatrix{$elty}(undef, m, m)
            elseif jobu === 'S' || jobu === 'O'
                oneMatrix{$elty}(undef, m, k)
            elseif jobu === 'N'
                oneMatrix{$elty}(undef, 0, 0)
            else
                error("jobu must be one of 'A', 'S', 'O', or 'N'")
            end
            # A 0×0 placeholder has a column stride of 0, so the clamp already
            # yields the leading dimension of 1 required for the unused factor;
            # no temporary device array or array comparison is needed.
            ldu = max(1, stride(U, 2))
            S = oneVector{$relty}(undef, k)

            Vt = if jobvt === 'A'
                oneMatrix{$elty}(undef, n, n)
            elseif jobvt === 'S' || jobvt === 'O'
                oneMatrix{$elty}(undef, k, n)
            elseif jobvt === 'N'
                oneMatrix{$elty}(undef, 0, 0)
            else
                error("jobvt must be one of 'A', 'S', 'O', or 'N'")
            end
            ldvt = max(1, stride(Vt, 2))

            queue = global_queue(context(A), device(A))
            scratchpad_size = $bname(sycl_queue(queue), jobu, jobvt, m, n, lda, ldu, ldvt)
            scratchpad = oneVector{$elty}(undef, scratchpad_size)
            $fname(sycl_queue(queue), jobu, jobvt, m, n, A, lda, S, U, ldu, Vt, ldvt, scratchpad, scratchpad_size)

            return U, S, Vt
        end
    end
end
341+
342+
# syevd and heevd
343+
for (jname, bname, fname, elty, relty) in ((:syevd!, :onemklSsyevd_scratchpad_size, :onemklSsyevd, :Float32, :Float32),
                                           (:syevd!, :onemklDsyevd_scratchpad_size, :onemklDsyevd, :Float64, :Float64),
                                           (:heevd!, :onemklCheevd_scratchpad_size, :onemklCheevd, :ComplexF32, :Float32),
                                           (:heevd!, :onemklZheevd_scratchpad_size, :onemklZheevd, :ComplexF64, :Float64))
    @eval begin
        # syevd!(jobz, uplo, A) / heevd!(jobz, uplo, A)
        #
        # Call the oneMKL `syevd` (real) / `heevd` (complex) routine on the
        # square device matrix `A`.
        #
        # - `jobz`: 'N' returns only `W`; 'V' returns `(W, A)`. Any other
        #   character raises an `ArgumentError` before the solver runs
        #   (previously the solver ran and the function returned `nothing`).
        # - `uplo`: checked with `chkuplo`.
        #
        # `W` is a real vector of length n filled by the library call; per the
        # oneMKL documentation it holds the eigenvalues, and `A` holds the
        # eigenvectors on return when `jobz == 'V'`.
        function $jname(jobz::Char,
                        uplo::Char,
                        A::oneStridedMatrix{$elty})
            if jobz != 'N' && jobz != 'V'
                throw(ArgumentError("jobz must be one of 'N' or 'V', got $jobz"))
            end
            chkuplo(uplo)
            n = checksquare(A)
            lda = max(1, stride(A, 2))
            W = oneVector{$relty}(undef, n)

            queue = global_queue(context(A), device(A))
            scratchpad_size = $bname(sycl_queue(queue), jobz, uplo, n, lda)
            scratchpad = oneVector{$elty}(undef, scratchpad_size)
            $fname(sycl_queue(queue), jobz, uplo, n, A, lda, W, scratchpad, scratchpad_size)

            # jobz was validated above, so exactly one of these branches applies.
            if jobz == 'N'
                return W
            else
                return W, A
            end
        end
    end
end
369+
370+
# sygvd and hegvd
371+
for (jname, bname, fname, elty, relty) in ((:sygvd!, :onemklSsygvd_scratchpad_size, :onemklSsygvd, :Float32, :Float32),
                                           (:sygvd!, :onemklDsygvd_scratchpad_size, :onemklDsygvd, :Float64, :Float64),
                                           (:hegvd!, :onemklChegvd_scratchpad_size, :onemklChegvd, :ComplexF32, :Float32),
                                           (:hegvd!, :onemklZhegvd_scratchpad_size, :onemklZhegvd, :ComplexF64, :Float64))
    @eval begin
        # sygvd!(itype, jobz, uplo, A, B) / hegvd!(itype, jobz, uplo, A, B)
        #
        # Call the oneMKL `sygvd` (real) / `hegvd` (complex) routine on the
        # square device matrices `A` and `B`, which must have the same order.
        #
        # - `itype`: problem type, must be 1, 2 or 3 (per the LAPACK/oneMKL
        #   documentation); otherwise an `ArgumentError` is thrown.
        # - `jobz`: 'N' returns only `W`; 'V' returns `(W, A, B)`. Any other
        #   character raises an `ArgumentError` before the solver runs
        #   (previously the solver ran and the function returned `nothing`).
        # - `uplo`: checked with `chkuplo`.
        #
        # Throws `DimensionMismatch` when `A` and `B` differ in order.
        # `W` is a real vector of length n filled by the library call.
        function $jname(itype::Int,
                        jobz::Char,
                        uplo::Char,
                        A::oneStridedMatrix{$elty},
                        B::oneStridedMatrix{$elty})
            if !(1 <= itype <= 3)
                throw(ArgumentError("itype must be 1, 2 or 3, got $itype"))
            end
            if jobz != 'N' && jobz != 'V'
                throw(ArgumentError("jobz must be one of 'N' or 'V', got $jobz"))
            end
            chkuplo(uplo)
            nA, nB = checksquare(A, B)
            if nB != nA
                throw(DimensionMismatch("Dimensions of A ($nA, $nA) and B ($nB, $nB) must match!"))
            end
            n = nA
            lda = max(1, stride(A, 2))
            ldb = max(1, stride(B, 2))
            W = oneVector{$relty}(undef, n)

            queue = global_queue(context(A), device(A))
            scratchpad_size = $bname(sycl_queue(queue), itype, jobz, uplo, n, lda, ldb)
            scratchpad = oneVector{$elty}(undef, scratchpad_size)
            $fname(sycl_queue(queue), itype, jobz, uplo, n, A, lda, B, ldb, W, scratchpad, scratchpad_size)

            # jobz was validated above, so exactly one of these branches applies.
            if jobz == 'N'
                return W
            else
                return W, A, B
            end
        end
    end
end
404+
297405
# getrf_batch
298406
for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf_batch, :Float32),
299407
(:onemklDgetrf_batch_scratchpad_size, :onemklDgetrf_batch, :Float64),
@@ -364,5 +472,35 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
364472
LinearAlgebra.LAPACK.getrs!(trans::Char, A::oneStridedMatrix{$elty}, ipiv::oneStridedVector{Int64}, B::oneStridedVecOrMat{$elty}) = oneMKL.getrs!(trans, A, ipiv, B)
365473
LinearAlgebra.LAPACK.ormqr!(side::Char, trans::Char, A::oneStridedMatrix{$elty}, tau::oneStridedVector{$elty}, C::oneStridedVecOrMat{$elty}) = oneMKL.ormqr!(side, trans, A, tau, C)
366474
LinearAlgebra.LAPACK.orgqr!(A::oneStridedMatrix{$elty}, tau::oneStridedVector{$elty}) = oneMKL.orgqr!(A, tau)
475+
LinearAlgebra.LAPACK.gebrd!(A::oneStridedMatrix{$elty}) = oneMKL.gebrd!(A)
476+
LinearAlgebra.LAPACK.gesvd!(jobu::Char, jobvt::Char, A::oneStridedMatrix{$elty}) = oneMKL.gesvd!(jobu, jobvt, A)
477+
end
478+
end
479+
480+
# Forward the LinearAlgebra.LAPACK `syev!` and `sygvd!` entry points to the
# oneMKL wrappers for real element types.
for T in (:Float32, :Float64)
    @eval function LinearAlgebra.LAPACK.syev!(jobz::Char, uplo::Char,
                                              A::oneStridedMatrix{$T})
        return oneMKL.syevd!(jobz, uplo, A)
    end
    @eval function LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char,
                                               A::oneStridedMatrix{$T},
                                               B::oneStridedMatrix{$T})
        return oneMKL.sygvd!(itype, jobz, uplo, A, B)
    end
end
486+
487+
# Forward the LinearAlgebra.LAPACK `syev!` and `sygvd!` entry points to the
# oneMKL Hermitian wrappers (`heevd!`, `hegvd!`) for complex element types.
for T in (:ComplexF32, :ComplexF64)
    @eval function LinearAlgebra.LAPACK.syev!(jobz::Char, uplo::Char,
                                              A::oneStridedMatrix{$T})
        return oneMKL.heevd!(jobz, uplo, A)
    end
    @eval function LinearAlgebra.LAPACK.sygvd!(itype::Int, jobz::Char, uplo::Char,
                                               A::oneStridedMatrix{$T},
                                               B::oneStridedMatrix{$T})
        return oneMKL.hegvd!(itype, jobz, uplo, A, B)
    end
end
493+
494+
# `LinearAlgebra.LAPACK.syevd!` methods are only added on Julia 1.10 and later
# (presumably because the generic function does not exist before that — confirm).
# Real element types forward to `oneMKL.syevd!`, complex ones to `oneMKL.heevd!`.
if VERSION >= v"1.10"
    for (T, impl) in ((:Float32, :syevd!),
                      (:Float64, :syevd!),
                      (:ComplexF32, :heevd!),
                      (:ComplexF64, :heevd!))
        @eval function LinearAlgebra.LAPACK.syevd!(jobz::Char, uplo::Char,
                                                   A::oneStridedMatrix{$T})
            return oneMKL.$impl(jobz, uplo, A)
        end
    end
end

lib/support/liboneapi_support.jl

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2637,7 +2637,7 @@ function onemklCgesvd(device_queue, jobu, jobvt, m, n, a, lda, s, u, ldu, vt, ld
26372637
@ccall liboneapi_support.onemklCgesvd(device_queue::syclQueue_t, jobu::onemklJobsvd,
26382638
jobvt::onemklJobsvd, m::Int64, n::Int64,
26392639
a::ZePtr{ComplexF32}, lda::Int64,
2640-
s::ZePtr{ComplexF32}, u::ZePtr{Float32},
2640+
s::ZePtr{Float32}, u::ZePtr{ComplexF32},
26412641
ldu::Int64, vt::ZePtr{ComplexF32}, ldvt::Int64,
26422642
scratchpad::ZePtr{ComplexF32},
26432643
scratchpad_size::Int64)::Cint
@@ -3496,17 +3496,17 @@ end
34963496

34973497
# ccall wrapper for `onemklCheevd` in liboneapi_support. The array arguments
# (`a`, `w`, `scratchpad`) are declared as `ZePtr`s rather than host `Ptr`s.
function onemklCheevd(device_queue, jobz, uplo, n, a, lda, w, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklCheevd(device_queue::syclQueue_t, jobz::onemklJob,
                                          uplo::onemklUplo, n::Int64, a::ZePtr{ComplexF32},
                                          lda::Int64, w::ZePtr{Float32},
                                          scratchpad::ZePtr{ComplexF32},
                                          scratchpad_size::Int64)::Cint
end
35043504

35053505
# ccall wrapper for `onemklZheevd` in liboneapi_support. The array arguments
# are declared as `ZePtr`s with double-precision element types
# (`ComplexF64` for `a`/`scratchpad`, `Float64` for `w`).
function onemklZheevd(device_queue, jobz, uplo, n, a, lda, w, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklZheevd(device_queue::syclQueue_t, jobz::onemklJob,
                                          uplo::onemklUplo, n::Int64, a::ZePtr{ComplexF64},
                                          lda::Int64, w::ZePtr{Float64},
                                          scratchpad::ZePtr{ComplexF64},
                                          scratchpad_size::Int64)::Cint
end
35123512

@@ -3528,19 +3528,19 @@ function onemklChegvd(device_queue, itype, jobz, uplo, n, a, lda, b, ldb, w, scr
35283528
scratchpad_size)
35293529
@ccall liboneapi_support.onemklChegvd(device_queue::syclQueue_t, itype::Int64,
35303530
jobz::onemklJob, uplo::onemklUplo, n::Int64,
3531-
a::Ptr{ComplexF32}, lda::Int64,
3532-
b::Ptr{ComplexF32}, ldb::Int64, w::Ptr{Cfloat},
3533-
scratchpad::Ptr{ComplexF32},
3531+
a::ZePtr{ComplexF32}, lda::Int64,
3532+
b::ZePtr{ComplexF32}, ldb::Int64,
3533+
w::ZePtr{Float32}, scratchpad::ZePtr{ComplexF32},
35343534
scratchpad_size::Int64)::Cint
35353535
end
35363536

35373537
# ccall wrapper for `onemklZhegvd` in liboneapi_support. The array arguments
# are declared as `ZePtr`s with double-precision element types
# (`ComplexF64` for `a`/`b`/`scratchpad`, `Float64` for `w`).
function onemklZhegvd(device_queue, itype, jobz, uplo, n, a, lda, b, ldb, w, scratchpad,
                      scratchpad_size)
    @ccall liboneapi_support.onemklZhegvd(device_queue::syclQueue_t, itype::Int64,
                                          jobz::onemklJob, uplo::onemklUplo, n::Int64,
                                          a::ZePtr{ComplexF64}, lda::Int64,
                                          b::ZePtr{ComplexF64}, ldb::Int64,
                                          w::ZePtr{Float64}, scratchpad::ZePtr{ComplexF64},
                                          scratchpad_size::Int64)::Cint
end
35463546

@@ -3790,17 +3790,17 @@ end
37903790

37913791
# ccall wrapper for `onemklDsyevd` in liboneapi_support. The array arguments
# (`a`, `w`, `scratchpad`) are declared as `ZePtr`s rather than host `Ptr`s.
function onemklDsyevd(device_queue, jobz, uplo, n, a, lda, w, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklDsyevd(device_queue::syclQueue_t, jobz::onemklJob,
                                          uplo::onemklUplo, n::Int64, a::ZePtr{Cdouble},
                                          lda::Int64, w::ZePtr{Cdouble},
                                          scratchpad::ZePtr{Cdouble},
                                          scratchpad_size::Int64)::Cint
end
37983798

37993799
# ccall wrapper for `onemklSsyevd` in liboneapi_support. The array arguments
# (`a`, `w`, `scratchpad`) are declared as `ZePtr`s rather than host `Ptr`s.
function onemklSsyevd(device_queue, jobz, uplo, n, a, lda, w, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklSsyevd(device_queue::syclQueue_t, jobz::onemklJob,
                                          uplo::onemklUplo, n::Int64, a::ZePtr{Cfloat},
                                          lda::Int64, w::ZePtr{Cfloat},
                                          scratchpad::ZePtr{Cfloat},
                                          scratchpad_size::Int64)::Cint
end
38063806

@@ -3868,19 +3868,19 @@ function onemklDsygvd(device_queue, itype, jobz, uplo, n, a, lda, b, ldb, w, scr
38683868
scratchpad_size)
38693869
@ccall liboneapi_support.onemklDsygvd(device_queue::syclQueue_t, itype::Int64,
38703870
jobz::onemklJob, uplo::onemklUplo, n::Int64,
3871-
a::Ptr{Cdouble}, lda::Int64, b::Ptr{Cdouble},
3872-
ldb::Int64, w::Ptr{Cdouble},
3873-
scratchpad::Ptr{Cdouble},
3871+
a::ZePtr{Cdouble}, lda::Int64, b::ZePtr{Cdouble},
3872+
ldb::Int64, w::ZePtr{Cdouble},
3873+
scratchpad::ZePtr{Cdouble},
38743874
scratchpad_size::Int64)::Cint
38753875
end
38763876

38773877
# ccall wrapper for `onemklSsygvd` in liboneapi_support. The array arguments
# (`a`, `b`, `w`, `scratchpad`) are declared as `ZePtr`s rather than host `Ptr`s.
function onemklSsygvd(device_queue, itype, jobz, uplo, n, a, lda, b, ldb, w, scratchpad,
                      scratchpad_size)
    @ccall liboneapi_support.onemklSsygvd(device_queue::syclQueue_t, itype::Int64,
                                          jobz::onemklJob, uplo::onemklUplo, n::Int64,
                                          a::ZePtr{Cfloat}, lda::Int64, b::ZePtr{Cfloat},
                                          ldb::Int64, w::ZePtr{Cfloat},
                                          scratchpad::ZePtr{Cfloat},
                                          scratchpad_size::Int64)::Cint
end
38863886

res/support.toml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,3 +559,36 @@ use_ccall_macro = true
559559
4 = "ZePtr{Ptr{T}}"
560560
6 = "ZePtr{Ptr{T}}"
561561
9 = "ZePtr{T}"
562+
563+
[api.onemklXsyevd.argtypes]
564+
5 = "ZePtr{T}"
565+
7 = "ZePtr{T}"
566+
8 = "ZePtr{T}"
567+
568+
[api.onemklCheevd.argtypes]
569+
5 = "ZePtr{ComplexF32}"
570+
7 = "ZePtr{Float32}"
571+
8 = "ZePtr{ComplexF32}"
572+
573+
[api.onemklZheevd.argtypes]
574+
5 = "ZePtr{ComplexF64}"
575+
7 = "ZePtr{Float64}"
576+
8 = "ZePtr{ComplexF64}"
577+
578+
[api.onemklXsygvd.argtypes]
579+
6 = "ZePtr{T}"
580+
8 = "ZePtr{T}"
581+
10 = "ZePtr{T}"
582+
11 = "ZePtr{T}"
583+
584+
[api.onemklChegvd.argtypes]
585+
6 = "ZePtr{ComplexF32}"
586+
8 = "ZePtr{ComplexF32}"
587+
10 = "ZePtr{Float32}"
588+
11 = "ZePtr{ComplexF32}"
589+
590+
[api.onemklZhegvd.argtypes]
591+
6 = "ZePtr{ComplexF64}"
592+
8 = "ZePtr{ComplexF64}"
593+
10 = "ZePtr{Float64}"
594+
11 = "ZePtr{ComplexF64}"

0 commit comments

Comments
 (0)