Skip to content

Commit 77c51eb

Browse files
committed
[oneMKL] Interface batched version of lapack routines
1 parent 5ee8dd9 commit 77c51eb

File tree

4 files changed

+212
-45
lines changed

4 files changed

+212
-45
lines changed

lib/mkl/wrappers_lapack.jl

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ for (bname, fname, elty) in ((:onemklSpotri_scratchpad_size, :onemklSpotri, :Flo
7070
end
7171
end
7272

73-
#sytrf
73+
# sytrf
7474
for (bname, fname, elty) in ((:onemklSsytrf_scratchpad_size, :onemklSsytrf, :Float32),
7575
(:onemklDsytrf_scratchpad_size, :onemklDsytrf, :Float64),
7676
(:onemklCsytrf_scratchpad_size, :onemklCsytrf, :ComplexF32),
@@ -402,6 +402,62 @@ for (jname, bname, fname, elty, relty) in ((:sygvd!, :onemklSsygvd_scratchpad_si
402402
end
403403
end
404404

405+
# potrf_batch
406+
for (bname, fname, elty) in ((:onemklSpotrf_batch_scratchpad_size, :onemklSpotrf_batch, :Float32),
407+
(:onemklDpotrf_batch_scratchpad_size, :onemklDpotrf_batch, :Float64),
408+
(:onemklCpotrf_batch_scratchpad_size, :onemklCpotrf_batch, :ComplexF32),
409+
(:onemklZpotrf_batch_scratchpad_size, :onemklZpotrf_batch, :ComplexF64))
410+
@eval begin
411+
function potrf_batched!(A::Vector{<:oneMatrix{$elty}})
412+
group_count = length(A)
413+
group_sizes = ones(Int64, group_count)
414+
uplo = [ONEMKL_UPLO_LOWER for i=1:group_count]
415+
n = [checksquare(A[i]) for i=1:group_count]
416+
lda = [max(1, stride(A[i], 2)) for i=1:group_count]
417+
Aptrs = unsafe_batch(A)
418+
419+
queue = global_queue(context(A[1]), device(A[1]))
420+
scratchpad_size = $bname(sycl_queue(queue), uplo, n, lda, group_count, group_sizes)
421+
scratchpad = oneVector{$elty}(undef, scratchpad_size)
422+
$fname(sycl_queue(queue), uplo, n, Aptrs, lda, group_count, group_sizes, scratchpad, scratchpad_size)
423+
424+
unsafe_free!(Aptrs)
425+
426+
return A
427+
end
428+
end
429+
end
430+
431+
# potrs_batch
432+
for (bname, fname, elty) in ((:onemklSpotrs_batch_scratchpad_size, :onemklSpotrs_batch, :Float32),
433+
(:onemklDpotrs_batch_scratchpad_size, :onemklDpotrs_batch, :Float64),
434+
(:onemklCpotrs_batch_scratchpad_size, :onemklCpotrs_batch, :ComplexF32),
435+
(:onemklZpotrs_batch_scratchpad_size, :onemklZpotrs_batch, :ComplexF64))
436+
@eval begin
437+
function potrs_batched!(A::Vector{<:oneMatrix{$elty}}, B::Vector{<:oneMatrix{$elty}})
438+
group_count = length(A)
439+
group_sizes = ones(Int64, group_count)
440+
uplo = [ONEMKL_UPLO_LOWER for i=1:group_count]
441+
n = [checksquare(A[i]) for i=1:group_count]
442+
nrhs = [size(B[i], 2) for i=1:group_count]
443+
lda = [max(1, stride(A[i], 2)) for i=1:group_count]
444+
ldb = [max(1, stride(B[i], 2)) for i=1:group_count]
445+
Aptrs = unsafe_batch(A)
446+
Bptrs = unsafe_batch(B)
447+
448+
queue = global_queue(context(A[1]), device(A[1]))
449+
scratchpad_size = $bname(sycl_queue(queue), uplo, n, nrhs, lda, ldb, group_count, group_sizes)
450+
scratchpad = oneVector{$elty}(undef, scratchpad_size)
451+
$fname(sycl_queue(queue), uplo, n, nrhs, Aptrs, lda, Bptrs, ldb, group_count, group_sizes, scratchpad, scratchpad_size)
452+
453+
unsafe_free!(Aptrs)
454+
unsafe_free!(Bptrs)
455+
456+
return A
457+
end
458+
end
459+
end
460+
405461
# getrf_batch
406462
for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf_batch, :Float32),
407463
(:onemklDgetrf_batch_scratchpad_size, :onemklDgetrf_batch, :Float64),
@@ -490,6 +546,64 @@ for (bname, fname, elty) in ((:onemklSgetri_batch_scratchpad_size, :onemklSgetri
490546
end
491547
end
492548

549+
# geqrf_batch
550+
for (bname, fname, elty) in ((:onemklSgeqrf_batch_scratchpad_size, :onemklSgeqrf_batch, :Float32),
551+
(:onemklDgeqrf_batch_scratchpad_size, :onemklDgeqrf_batch, :Float64),
552+
(:onemklCgeqrf_batch_scratchpad_size, :onemklCgeqrf_batch, :ComplexF32),
553+
(:onemklZgeqrf_batch_scratchpad_size, :onemklZgeqrf_batch, :ComplexF64))
554+
@eval begin
555+
function geqrf_batched!(A::Vector{<:oneMatrix{$elty}})
556+
group_count = length(A)
557+
group_sizes = ones(Int64, group_count)
558+
m = [size(A[i], 1) for i=1:group_count]
559+
n = [size(A[i], 2) for i=1:group_count]
560+
lda = [max(1, stride(A[i], 2)) for i=1:group_count]
561+
tau = [oneVector{$elty}(undef, min(m[i], n[i])) for i=1:group_count]
562+
Aptrs = unsafe_batch(A)
563+
tauptrs = unsafe_batch(tau)
564+
565+
queue = global_queue(context(A[1]), device(A[1]))
566+
scratchpad_size = $bname(sycl_queue(queue), m, n, lda, group_count, group_sizes)
567+
scratchpad = oneVector{$elty}(undef, scratchpad_size)
568+
$fname(sycl_queue(queue), m, n, Aptrs, lda, tauptrs, group_count, group_sizes, scratchpad, scratchpad_size)
569+
570+
unsafe_free!(Aptrs)
571+
unsafe_free!(tauptrs)
572+
573+
return tau, A
574+
end
575+
end
576+
end
577+
578+
# orgqr_batch and ungqr_batch
579+
for (bname, fname, elty) in ((:onemklSorgqr_batch_scratchpad_size, :onemklSorgqr_batch, :Float32),
580+
(:onemklDorgqr_batch_scratchpad_size, :onemklDorgqr_batch, :Float64),
581+
(:onemklCungqr_batch_scratchpad_size, :onemklCungqr_batch, :ComplexF32),
582+
(:onemklZungqr_batch_scratchpad_size, :onemklZungqr_batch, :ComplexF64))
583+
@eval begin
584+
function orgqr_batched!(A::Vector{<:oneMatrix{$elty}}, tau::Vector{<:oneVector{$elty}})
585+
group_count = length(A)
586+
group_sizes = ones(Int64, group_count)
587+
m = [size(A[i], 1) for i=1:group_count]
588+
n = [size(A[i], 2) for i=1:group_count]
589+
k = [min(m[i], n[i]) for i=1:group_count]
590+
lda = [max(1, stride(A[i], 2)) for i=1:group_count]
591+
Aptrs = unsafe_batch(A)
592+
tauptrs = unsafe_batch(tau)
593+
594+
queue = global_queue(context(A[1]), device(A[1]))
595+
scratchpad_size = $bname(sycl_queue(queue), m, n, k, lda, group_count, group_sizes)
596+
scratchpad = oneVector{$elty}(undef, scratchpad_size)
597+
$fname(sycl_queue(queue), m, n, k, Aptrs, lda, tauptrs, group_count, group_sizes, scratchpad, scratchpad_size)
598+
599+
unsafe_free!(Aptrs)
600+
unsafe_free!(tauptrs)
601+
602+
return A
603+
end
604+
end
605+
end
606+
493607
# LAPACK
494608
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
495609
@eval begin

lib/support/liboneapi_support.jl

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4443,87 +4443,87 @@ function onemklSpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_si
44434443
scratchpad, scratchpad_size)
44444444
@ccall liboneapi_support.onemklSpotrf_batch(device_queue::syclQueue_t,
44454445
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4446-
a::Ptr{Ptr{Cfloat}}, lda::Ptr{Int64},
4446+
a::ZePtr{Ptr{Cfloat}}, lda::Ptr{Int64},
44474447
group_count::Int64, group_sizes::Ptr{Int64},
4448-
scratchpad::Ptr{Cfloat},
4448+
scratchpad::ZePtr{Cfloat},
44494449
scratchpad_size::Int64)::Cint
44504450
end
44514451

44524452
function onemklDpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
44534453
scratchpad, scratchpad_size)
44544454
@ccall liboneapi_support.onemklDpotrf_batch(device_queue::syclQueue_t,
44554455
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4456-
a::Ptr{Ptr{Cdouble}}, lda::Ptr{Int64},
4456+
a::ZePtr{Ptr{Cdouble}}, lda::Ptr{Int64},
44574457
group_count::Int64, group_sizes::Ptr{Int64},
4458-
scratchpad::Ptr{Cdouble},
4458+
scratchpad::ZePtr{Cdouble},
44594459
scratchpad_size::Int64)::Cint
44604460
end
44614461

44624462
function onemklCpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
44634463
scratchpad, scratchpad_size)
44644464
@ccall liboneapi_support.onemklCpotrf_batch(device_queue::syclQueue_t,
44654465
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4466-
a::Ptr{Ptr{ComplexF32}}, lda::Ptr{Int64},
4466+
a::ZePtr{Ptr{ComplexF32}}, lda::Ptr{Int64},
44674467
group_count::Int64, group_sizes::Ptr{Int64},
4468-
scratchpad::Ptr{ComplexF32},
4468+
scratchpad::ZePtr{ComplexF32},
44694469
scratchpad_size::Int64)::Cint
44704470
end
44714471

44724472
function onemklZpotrf_batch(device_queue, uplo, n, a, lda, group_count, group_sizes,
44734473
scratchpad, scratchpad_size)
44744474
@ccall liboneapi_support.onemklZpotrf_batch(device_queue::syclQueue_t,
44754475
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4476-
a::Ptr{Ptr{ComplexF32}}, lda::Ptr{Int64},
4476+
a::ZePtr{Ptr{ComplexF64}}, lda::Ptr{Int64},
44774477
group_count::Int64, group_sizes::Ptr{Int64},
4478-
scratchpad::Ptr{ComplexF32},
4478+
scratchpad::ZePtr{ComplexF64},
44794479
scratchpad_size::Int64)::Cint
44804480
end
44814481

44824482
function onemklSpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
44834483
group_sizes, scratchpad, scratchpad_size)
44844484
@ccall liboneapi_support.onemklSpotrs_batch(device_queue::syclQueue_t,
44854485
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4486-
nrhs::Ptr{Int64}, a::Ptr{Ptr{Cfloat}},
4487-
lda::Ptr{Int64}, b::Ptr{Ptr{Cfloat}},
4486+
nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cfloat}},
4487+
lda::Ptr{Int64}, b::ZePtr{Ptr{Cfloat}},
44884488
ldb::Ptr{Int64}, group_count::Int64,
44894489
group_sizes::Ptr{Int64},
4490-
scratchpad::Ptr{Cfloat},
4490+
scratchpad::ZePtr{Cfloat},
44914491
scratchpad_size::Int64)::Cint
44924492
end
44934493

44944494
function onemklDpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
44954495
group_sizes, scratchpad, scratchpad_size)
44964496
@ccall liboneapi_support.onemklDpotrs_batch(device_queue::syclQueue_t,
44974497
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4498-
nrhs::Ptr{Int64}, a::Ptr{Ptr{Cdouble}},
4499-
lda::Ptr{Int64}, b::Ptr{Ptr{Cdouble}},
4498+
nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cdouble}},
4499+
lda::Ptr{Int64}, b::ZePtr{Ptr{Cdouble}},
45004500
ldb::Ptr{Int64}, group_count::Int64,
45014501
group_sizes::Ptr{Int64},
4502-
scratchpad::Ptr{Cdouble},
4502+
scratchpad::ZePtr{Cdouble},
45034503
scratchpad_size::Int64)::Cint
45044504
end
45054505

45064506
function onemklCpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
45074507
group_sizes, scratchpad, scratchpad_size)
45084508
@ccall liboneapi_support.onemklCpotrs_batch(device_queue::syclQueue_t,
45094509
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4510-
nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}},
4511-
lda::Ptr{Int64}, b::Ptr{Ptr{ComplexF32}},
4510+
nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF32}},
4511+
lda::Ptr{Int64}, b::ZePtr{Ptr{ComplexF32}},
45124512
ldb::Ptr{Int64}, group_count::Int64,
45134513
group_sizes::Ptr{Int64},
4514-
scratchpad::Ptr{ComplexF32},
4514+
scratchpad::ZePtr{ComplexF32},
45154515
scratchpad_size::Int64)::Cint
45164516
end
45174517

45184518
function onemklZpotrs_batch(device_queue, uplo, n, nrhs, a, lda, b, ldb, group_count,
45194519
group_sizes, scratchpad, scratchpad_size)
45204520
@ccall liboneapi_support.onemklZpotrs_batch(device_queue::syclQueue_t,
45214521
uplo::Ptr{onemklUplo}, n::Ptr{Int64},
4522-
nrhs::Ptr{Int64}, a::Ptr{Ptr{ComplexF32}},
4523-
lda::Ptr{Int64}, b::Ptr{Ptr{ComplexF32}},
4522+
nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF64}},
4523+
lda::Ptr{Int64}, b::ZePtr{Ptr{ComplexF64}},
45244524
ldb::Ptr{Int64}, group_count::Int64,
45254525
group_sizes::Ptr{Int64},
4526-
scratchpad::Ptr{ComplexF32},
4526+
scratchpad::ZePtr{ComplexF64},
45274527
scratchpad_size::Int64)::Cint
45284528
end
45294529

@@ -4697,43 +4697,43 @@ function onemklSorgqr_batch(device_queue, m, n, k, a, lda, tau, group_count, gro
46974697
scratchpad, scratchpad_size)
46984698
@ccall liboneapi_support.onemklSorgqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
46994699
n::Ptr{Int64}, k::Ptr{Int64},
4700-
a::Ptr{Ptr{Cfloat}}, lda::Ptr{Int64},
4701-
tau::Ptr{Ptr{Cfloat}}, group_count::Int64,
4700+
a::ZePtr{Ptr{Cfloat}}, lda::Ptr{Int64},
4701+
tau::ZePtr{Ptr{Cfloat}}, group_count::Int64,
47024702
group_sizes::Ptr{Int64},
4703-
scratchpad::Ptr{Cfloat},
4703+
scratchpad::ZePtr{Cfloat},
47044704
scratchpad_size::Int64)::Cint
47054705
end
47064706

47074707
function onemklDorgqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
47084708
scratchpad, scratchpad_size)
47094709
@ccall liboneapi_support.onemklDorgqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
47104710
n::Ptr{Int64}, k::Ptr{Int64},
4711-
a::Ptr{Ptr{Cdouble}}, lda::Ptr{Int64},
4712-
tau::Ptr{Ptr{Cdouble}}, group_count::Int64,
4713-
group_sizes::Ptr{Int64},
4714-
scratchpad::Ptr{Cdouble},
4711+
a::ZePtr{Ptr{Cdouble}}, lda::Ptr{Int64},
4712+
tau::ZePtr{Ptr{Cdouble}},
4713+
group_count::Int64, group_sizes::Ptr{Int64},
4714+
scratchpad::ZePtr{Cdouble},
47154715
scratchpad_size::Int64)::Cint
47164716
end
47174717

47184718
function onemklCungqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
47194719
scratchpad, scratchpad_size)
47204720
@ccall liboneapi_support.onemklCungqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
47214721
n::Ptr{Int64}, k::Ptr{Int64},
4722-
a::Ptr{Ptr{ComplexF32}}, lda::Ptr{Int64},
4723-
tau::Ptr{Ptr{ComplexF32}},
4722+
a::ZePtr{Ptr{ComplexF32}}, lda::Ptr{Int64},
4723+
tau::ZePtr{Ptr{ComplexF32}},
47244724
group_count::Int64, group_sizes::Ptr{Int64},
4725-
scratchpad::Ptr{ComplexF32},
4725+
scratchpad::ZePtr{ComplexF32},
47264726
scratchpad_size::Int64)::Cint
47274727
end
47284728

47294729
function onemklZungqr_batch(device_queue, m, n, k, a, lda, tau, group_count, group_sizes,
47304730
scratchpad, scratchpad_size)
47314731
@ccall liboneapi_support.onemklZungqr_batch(device_queue::syclQueue_t, m::Ptr{Int64},
47324732
n::Ptr{Int64}, k::Ptr{Int64},
4733-
a::Ptr{Ptr{ComplexF32}}, lda::Ptr{Int64},
4734-
tau::Ptr{Ptr{ComplexF32}},
4733+
a::ZePtr{Ptr{ComplexF64}}, lda::Ptr{Int64},
4734+
tau::ZePtr{Ptr{ComplexF64}},
47354735
group_count::Int64, group_sizes::Ptr{Int64},
4736-
scratchpad::Ptr{ComplexF32},
4736+
scratchpad::ZePtr{ComplexF64},
47374737
scratchpad_size::Int64)::Cint
47384738
end
47394739

res/support.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,25 @@ use_ccall_macro = true
566566
6 = "ZePtr{Ptr{T}}"
567567
9 = "ZePtr{T}"
568568

569+
[api.onemklXorgqr_batch.argtypes]
570+
5 = "ZePtr{Ptr{T}}"
571+
7 = "ZePtr{Ptr{T}}"
572+
10 = "ZePtr{T}"
573+
574+
[api.onemklXungqr_batch.argtypes]
575+
5 = "ZePtr{Ptr{T}}"
576+
7 = "ZePtr{Ptr{T}}"
577+
10 = "ZePtr{T}"
578+
579+
[api.onemklXpotrf_batch.argtypes]
580+
4 = "ZePtr{Ptr{T}}"
581+
8 = "ZePtr{T}"
582+
583+
[api.onemklXpotrs_batch.argtypes]
584+
5 = "ZePtr{Ptr{T}}"
585+
7 = "ZePtr{Ptr{T}}"
586+
11 = "ZePtr{T}"
587+
569588
[api.onemklXsyevd.argtypes]
570589
5 = "ZePtr{T}"
571590
7 = "ZePtr{T}"

0 commit comments

Comments
 (0)