Skip to content

Commit 5ee8dd9

Browse files
committed
[oneMKL] Interface getrs_batched!
1 parent b6a393f commit 5ee8dd9

File tree

4 files changed

+72
-16
lines changed

4 files changed

+72
-16
lines changed

lib/mkl/wrappers_lapack.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,38 @@ for (bname, fname, elty) in ((:onemklSgetrf_batch_scratchpad_size, :onemklSgetrf
431431
end
432432
end
433433

434+
# getrs_batch
435+
for (bname, fname, elty) in
    ((:onemklSgetrs_batch_scratchpad_size, :onemklSgetrs_batch, :Float32),
     (:onemklDgetrs_batch_scratchpad_size, :onemklDgetrs_batch, :Float64),
     (:onemklCgetrs_batch_scratchpad_size, :onemklCgetrs_batch, :ComplexF32),
     (:onemklZgetrs_batch_scratchpad_size, :onemklZgetrs_batch, :ComplexF64))
    @eval begin
        # getrs_batched!(A, ipiv, B) -> B
        #
        # Forward a batch of matrices to the oneMKL `getrs_batch` group API, using
        # one group of size 1 per matrix and ONEMKL_TRANSPOSE_NONTRANS for every
        # group. `B` is handed to the library and returned.
        #
        # Arguments
        #   A    : square matrices (typically the factors returned by getrf_batched!)
        #   ipiv : one Int64 pivot vector per matrix in `A`
        #   B    : right-hand sides; `B[i]` must have as many rows as `A[i]`
        #
        # Throws
        #   DimensionMismatch if the three batches differ in length, if a matrix in
        #   `A` is not square, or if `B[i]` / `ipiv[i]` do not fit the order of `A[i]`.
        function getrs_batched!(A::Vector{<:oneMatrix{$elty}},
                                ipiv::Vector{<:oneVector{Int64}},
                                B::Vector{<:oneMatrix{$elty}})
            group_count = length(A)
            if length(ipiv) != group_count || length(B) != group_count
                throw(DimensionMismatch(string(
                    "A, ipiv and B must have the same batch length, got ",
                    group_count, ", ", length(ipiv), " and ", length(B))))
            end
            # Nothing to solve; also avoids indexing A[1] on an empty batch below.
            group_count == 0 && return B

            n = Int64[checksquare(a) for a in A]
            for i in 1:group_count
                if size(B[i], 1) != n[i]
                    throw(DimensionMismatch(string(
                        "B[", i, "] has ", size(B[i], 1), " rows but A[", i,
                        "] has order ", n[i])))
                end
                if length(ipiv[i]) < n[i]
                    throw(DimensionMismatch(string(
                        "ipiv[", i, "] has length ", length(ipiv[i]),
                        " but A[", i, "] has order ", n[i])))
                end
            end

            group_sizes = ones(Int64, group_count)
            trans = fill(ONEMKL_TRANSPOSE_NONTRANS, group_count)
            nrhs = Int64[size(b, 2) for b in B]
            lda = Int64[max(1, stride(a, 2)) for a in A]
            ldb = Int64[max(1, stride(b, 2)) for b in B]

            Aptrs = unsafe_batch(A)
            Bptrs = unsafe_batch(B)
            ipivptrs = unsafe_batch(ipiv)
            try
                queue = global_queue(context(A[1]), device(A[1]))
                scratchpad_size = $bname(sycl_queue(queue), trans, n, nrhs, lda, ldb,
                                         group_count, group_sizes)
                scratchpad = oneVector{$elty}(undef, scratchpad_size)
                $fname(sycl_queue(queue), trans, n, nrhs, Aptrs, lda, ipivptrs, Bptrs,
                       ldb, group_count, group_sizes, scratchpad, scratchpad_size)
            finally
                # Release the pointer batches even when the library call throws.
                unsafe_free!(Aptrs)
                unsafe_free!(Bptrs)
                unsafe_free!(ipivptrs)
            end

            return B
        end
    end
end
465+
434466
# getri_batch
435467
for (bname, fname, elty) in ((:onemklSgetri_batch_scratchpad_size, :onemklSgetri_batch, :Float32),
436468
(:onemklDgetri_batch_scratchpad_size, :onemklDgetri_batch, :Float64),

lib/support/liboneapi_support.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4567,47 +4567,47 @@ function onemklSgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb,
45674567
group_sizes, scratchpad, scratchpad_size)
45684568
@ccall liboneapi_support.onemklSgetrs_batch(device_queue::syclQueue_t,
45694569
trans::Ptr{onemklTranspose}, n::Ptr{Int64},
4570-
nrhs::Ptr{Int64}, a::Ptr{Ptr{Cfloat}},
4571-
lda::Ptr{Int64}, ipiv::Ptr{Ptr{Int64}},
4572-
b::Ptr{Ptr{Cfloat}}, ldb::Ptr{Int64},
4570+
nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cfloat}},
4571+
lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
4572+
b::ZePtr{Ptr{Cfloat}}, ldb::Ptr{Int64},
45734573
group_count::Int64, group_sizes::Ptr{Int64},
4574-
scratchpad::Ptr{Cfloat},
4574+
scratchpad::ZePtr{Cfloat},
45754575
scratchpad_size::Int64)::Cint
45764576
end
45774577

45784578
# Raw binding for `onemklDgetrs_batch` in liboneapi_support (Cdouble variant).
# `a`, `ipiv`, `b` and `scratchpad` are declared as ZePtr; `trans`, `n`, `nrhs`,
# `lda`, `ldb` and `group_sizes` are declared as plain Ptr. Returns the Cint
# produced by the C function.
function onemklDgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklDgetrs_batch(device_queue::syclQueue_t,
                                                trans::Ptr{onemklTranspose}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{Cdouble}},
                                                lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
                                                b::ZePtr{Ptr{Cdouble}}, ldb::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{Cdouble},
                                                scratchpad_size::Int64)::Cint
end
45894589

45904590
# Raw binding for `onemklCgetrs_batch` in liboneapi_support (ComplexF32 variant).
# `a`, `ipiv`, `b` and `scratchpad` are declared as ZePtr; `trans`, `n`, `nrhs`,
# `lda`, `ldb` and `group_sizes` are declared as plain Ptr. Returns the Cint
# produced by the C function.
function onemklCgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklCgetrs_batch(device_queue::syclQueue_t,
                                                trans::Ptr{onemklTranspose}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF32}},
                                                lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
                                                b::ZePtr{Ptr{ComplexF32}}, ldb::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF32},
                                                scratchpad_size::Int64)::Cint
end
46014601

46024602
# Raw binding for `onemklZgetrs_batch` in liboneapi_support (ComplexF64 variant).
# `a`, `ipiv`, `b` and `scratchpad` are declared as ZePtr; `trans`, `n`, `nrhs`,
# `lda`, `ldb` and `group_sizes` are declared as plain Ptr. The element type is
# ComplexF64 throughout (not ComplexF32), matching the Z precision of the
# function name. Returns the Cint produced by the C function.
function onemklZgetrs_batch(device_queue, trans, n, nrhs, a, lda, ipiv, b, ldb, group_count,
                            group_sizes, scratchpad, scratchpad_size)
    @ccall liboneapi_support.onemklZgetrs_batch(device_queue::syclQueue_t,
                                                trans::Ptr{onemklTranspose}, n::Ptr{Int64},
                                                nrhs::Ptr{Int64}, a::ZePtr{Ptr{ComplexF64}},
                                                lda::Ptr{Int64}, ipiv::ZePtr{Ptr{Int64}},
                                                b::ZePtr{Ptr{ComplexF64}}, ldb::Ptr{Int64},
                                                group_count::Int64, group_sizes::Ptr{Int64},
                                                scratchpad::ZePtr{ComplexF64},
                                                scratchpad_size::Int64)::Cint
end
46134613

res/support.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,12 @@ use_ccall_macro = true
550550
6 = "ZePtr{Ptr{Int64}}"
551551
9 = "ZePtr{T}"
552552

553+
[api.onemklXgetrs_batch.argtypes]
554+
5 = "ZePtr{Ptr{T}}"
555+
7 = "ZePtr{Ptr{Int64}}"
556+
8 = "ZePtr{Ptr{T}}"
557+
12 = "ZePtr{T}"
558+
553559
[api.onemklXgetri_batch.argtypes]
554560
3 = "ZePtr{Ptr{T}}"
555561
5 = "ZePtr{Ptr{Int64}}"

test/onemkl.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,24 @@ end
13001300
end
13011301
end
13021302

1303+
# Factor a batch of random square matrices with getrf_batched!, solve against a
# batch of right-hand sides with getrs_batched!, and check A * X ≈ B on the host.
# `elty`, `m`, `n` and `p` come from the enclosing test scope.
@testset "getrs_batched!" begin
    # Host copies are kept untouched so the residual uses the original data.
    bA = [rand(elty, m, m) for i in 1:p]
    bB = [rand(elty, m, n) for i in 1:p]
    d_bA = oneMatrix{elty}[]
    d_bB = oneMatrix{elty}[]
    for i in 1:p
        push!(d_bA, oneMatrix(bA[i]))
        push!(d_bB, oneMatrix(bB[i]))
    end

    d_ipiv, d_bA = oneMKL.getrf_batched!(d_bA)
    d_bX = oneMKL.getrs_batched!(d_bA, d_ipiv, d_bB)
    h_bX = [collect(d_bX[i]) for i in 1:p]
    for i in 1:p
        # Compare against the host solution `h_bX` with an approximate equality.
        @test bA[i] * h_bX[i] ≈ bB[i]
    end
end
1320+
13031321
@testset "gebrd!" begin
13041322
A = rand(elty,m,n)
13051323
d_A = oneArray(A)

0 commit comments

Comments
 (0)