Skip to content

Commit 24dc33b

Browse files
committed
Update the C interface to avoid conflicts between buffer and usm routines
1 parent 39eaeee commit 24dc33b

File tree

6 files changed

+499
-480
lines changed

6 files changed

+499
-480
lines changed

deps/generate_interfaces.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,6 @@ function generate_headers(library::String, filename::String, output::String)
222222
copy_header = replace(copy_header, "typename fp_type::value_type" => version_types_header[blas_version])
223223
copy_header = replace(copy_header, "fp_type" => version_types_header[blas_version])
224224
copy_header = replace(copy_header, name_routine => "onemkl$(blas_version)$(name_routine)")
225-
copy_header = replace(copy_header, "void onemkl" => "int onemkl")
226225
push!(signatures, (copy_header, name_routine, blas_version, template))
227226
end
228227
else
@@ -232,9 +231,7 @@ function generate_headers(library::String, filename::String, output::String)
232231
occursin("int64_t", header) && (suffix = "_64")
233232
end
234233
header = replace(header, "$(name_routine)(" => "onemkl$(version)$(name_routine)$(suffix)(")
235-
if name_routine ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")
236-
header = replace(header, "void onemkl" => "int onemkl")
237-
end
234+
header = replace(header, "void onemkl" => "int onemkl")
238235
if library == "sparse"
239236
if occursin("std::complex", header)
240237
(version == 'C') && (header = replace(header, "std::complex " => "float _Complex "))
@@ -366,20 +363,20 @@ function generate_cpp(library::String, filename::String, output::String)
366363
write(oneapi_cpp, "extern \"C\" $header {\n")
367364
if template
368365
type = version_types[version]
369-
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
366+
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
370367
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
371368
else
372-
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")
373-
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters);\n")
369+
if name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")
370+
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters, {});\n")
374371
else
375372
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
376373
end
377374
end
378375
if occursin("scratchpad_size", name)
379376
write(oneapi_cpp, " return scratchpad_size;\n")
380377
else
381-
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
382-
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data", "set_csr_data")) && write(oneapi_cpp, " return 0;\n")
378+
(name ("init_matrix_handle", "init_matmat_descr", "release_matmat_descr", "set_matmat_data")) && write(oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n")
379+
write(oneapi_cpp, " return 0;\n")
383380
end
384381
write(oneapi_cpp, "}")
385382
write(oneapi_cpp, "\n\n")

deps/onemkl_epilogue.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ int clean_gpu_caches();
1212
}
1313
}
1414

15-
extern "C" void onemklDestroy() {
15+
extern "C" int onemklDestroy() {
1616
oneapi::mkl::gpu::clean_gpu_caches();
17+
return 0;
1718
}

deps/onemkl_epilogue.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
void onemklDestroy(void);
1+
int onemklDestroy(void);
22
#ifdef __cplusplus
33
}
44
#endif

deps/onemkl_prologue.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ extern "C" int onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose tran
328328
reinterpret_cast<const sycl::half **>(&a[0]), lda,
329329
reinterpret_cast<const sycl::half **>(&b[0]), ldb,
330330
reinterpret_cast<sycl::half *>(beta), reinterpret_cast<sycl::half **>(&c[0]),
331-
ldc, group_count, group_size);
331+
ldc, group_count, group_size, {});
332332
__FORCE_MKL_FLUSH__(status);
333333
return 0;
334334
}
@@ -346,7 +346,7 @@ extern "C" int onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tran
346346
(const float **)&a[0], lda,
347347
(const float **)&b[0], ldb,
348348
beta, &c[0], ldc,
349-
group_count, group_size);
349+
group_count, group_size, {});
350350
__FORCE_MKL_FLUSH__(status);
351351
return 0;
352352
}
@@ -364,7 +364,7 @@ extern "C" int onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose tran
364364
(const double **)&a[0], lda,
365365
(const double **)&b[0], ldb,
366366
beta, &c[0], ldc,
367-
group_count, group_size);
367+
group_count, group_size, {});
368368
__FORCE_MKL_FLUSH__(status);
369369
return 0;
370370
}
@@ -386,7 +386,7 @@ extern "C" int onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose tran
386386
ldb,
387387
reinterpret_cast<std::complex<float> *>(beta),
388388
reinterpret_cast<std::complex<float> **>(&c[0]), ldc,
389-
group_count, group_size);
389+
group_count, group_size, {});
390390
__FORCE_MKL_FLUSH__(status);
391391
return 0;
392392
}
@@ -409,7 +409,7 @@ extern "C" int onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose tran
409409
ldb,
410410
reinterpret_cast<std::complex<double> *>(beta),
411411
reinterpret_cast<std::complex<double> **>(&c[0]), ldc,
412-
group_count, group_size);
412+
group_count, group_size, {});
413413
__FORCE_MKL_FLUSH__(status);
414414
return 0;
415415
}
@@ -426,7 +426,7 @@ extern "C" int onemklStrsmBatched(syclQueue_t device_queue, onemklSide left_righ
426426
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
427427
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
428428
m, n, alpha, (const float **)&a[0], lda,
429-
&b[0], ldb, group_count, group_size);
429+
&b[0], ldb, group_count, group_size, {});
430430
__FORCE_MKL_FLUSH__(status);
431431
return 0;
432432
}
@@ -444,7 +444,7 @@ extern "C" int onemklDtrsmBatched(syclQueue_t device_queue, onemklSide left_righ
444444
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
445445
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
446446
m, n, alpha, (const double **)&a[0], lda, &b[0],
447-
ldb, group_count, group_size);
447+
ldb, group_count, group_size, {});
448448
__FORCE_MKL_FLUSH__(status);
449449
return 0;
450450
}
@@ -464,7 +464,7 @@ extern "C" int onemklCtrsmBatched(syclQueue_t device_queue, onemklSide left_righ
464464
m, n, reinterpret_cast<std::complex<float> *>(alpha),
465465
reinterpret_cast<const std::complex<float> **>(&a[0]),
466466
lda, reinterpret_cast<std::complex<float> **>(&b[0]),
467-
ldb, group_count, group_size);
467+
ldb, group_count, group_size, {});
468468
__FORCE_MKL_FLUSH__(status);
469469
return 0;
470470
}
@@ -484,7 +484,7 @@ extern "C" int onemklZtrsmBatched(syclQueue_t device_queue, onemklSide left_righ
484484
m, n, reinterpret_cast<std::complex<double> *>(alpha),
485485
reinterpret_cast<const std::complex<double> **>(&a[0]),
486486
lda, reinterpret_cast<std::complex<double> **>(&b[0]),
487-
ldb, group_count, group_size);
487+
ldb, group_count, group_size, {});
488488
__FORCE_MKL_FLUSH__(status);
489489
return 0;
490490
}
@@ -499,7 +499,7 @@ extern "C" int onemklHgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
499499
reinterpret_cast<const sycl::half *>(a), lda, stridea,
500500
reinterpret_cast<const sycl::half *>(b), ldb, strideb,
501501
sycl::bit_cast<sycl::half>(beta),
502-
reinterpret_cast<sycl::half *>(c), ldc, stridec, batch_size);
502+
reinterpret_cast<sycl::half *>(c), ldc, stridec, batch_size, {});
503503
__FORCE_MKL_FLUSH__(status);
504504
return 0;
505505
}
@@ -511,7 +511,7 @@ extern "C" int onemklSgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
511511
float *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
512512
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
513513
convert(transb), m, n, k, alpha, a, lda, stridea,
514-
b, ldb, strideb, beta, c, ldc, stridec, batch_size);
514+
b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
515515
__FORCE_MKL_FLUSH__(status);
516516
return 0;
517517
}
@@ -523,7 +523,7 @@ extern "C" int onemklDgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
523523
double *c, int64_t ldc, int64_t stridec, int64_t batch_size) {
524524
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, convert(transa),
525525
convert(transb), m, n, k, alpha, a, lda, stridea,
526-
b, ldb, strideb, beta, c, ldc, stridec, batch_size);
526+
b, ldb, strideb, beta, c, ldc, stridec, batch_size, {});
527527
__FORCE_MKL_FLUSH__(status);
528528
return 0;
529529
}
@@ -540,7 +540,7 @@ extern "C" int onemklCgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
540540
reinterpret_cast<const std::complex<float> *>(b),
541541
ldb, strideb, beta,
542542
reinterpret_cast<std::complex<float> *>(c),
543-
ldc, stridec, batch_size);
543+
ldc, stridec, batch_size, {});
544544
__FORCE_MKL_FLUSH__(status);
545545
return 0;
546546
}
@@ -557,7 +557,7 @@ extern "C" int onemklZgemmBatchStrided(syclQueue_t device_queue, onemklTranspose
557557
reinterpret_cast<const std::complex<double> *>(b),
558558
ldb, strideb, beta,
559559
reinterpret_cast<std::complex<double> *>(c),
560-
ldc, stridec, batch_size);
560+
ldc, stridec, batch_size, {});
561561
__FORCE_MKL_FLUSH__(status);
562562
return 0;
563563
}
@@ -572,7 +572,7 @@ extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA,
572572
reinterpret_cast<const sycl::half *>(A), lda,
573573
reinterpret_cast<const sycl::half *>(B), ldb,
574574
sycl::bit_cast<sycl::half>(beta),
575-
reinterpret_cast<sycl::half *>(C), ldc);
575+
reinterpret_cast<sycl::half *>(C), ldc, {});
576576
__FORCE_MKL_FLUSH__(status);
577577
return 0;
578578
}
@@ -583,7 +583,7 @@ extern "C" int onemklHdot(syclQueue_t device_queue, int64_t n,
583583
auto status = oneapi::mkl::blas::column_major::dot(device_queue->val, n,
584584
reinterpret_cast<const sycl::half *>(x),
585585
incx, reinterpret_cast<const sycl::half *>(y),
586-
incy, reinterpret_cast<sycl::half *>(result));
586+
incy, reinterpret_cast<sycl::half *>(result), {});
587587
__FORCE_MKL_FLUSH__(status);
588588
return 0;
589589
}
@@ -593,15 +593,15 @@ extern "C" int onemklHaxpy(syclQueue_t device_queue, int64_t n, uint16_t alpha,
593593
auto status = oneapi::mkl::blas::column_major::axpy(device_queue->val, n,
594594
sycl::bit_cast<sycl::half>(alpha),
595595
reinterpret_cast<const sycl::half *>(x),
596-
incx, reinterpret_cast<sycl::half *>(y), incy);
596+
incx, reinterpret_cast<sycl::half *>(y), incy, {});
597597
__FORCE_MKL_FLUSH__(status);
598598
return 0;
599599
}
600600

601601
extern "C" int onemklHscal(syclQueue_t device_queue, int64_t n, uint16_t alpha,
602602
short *x, int64_t incx) {
603603
auto status = oneapi::mkl::blas::column_major::scal(device_queue->val, n, sycl::bit_cast<sycl::half>(alpha),
604-
reinterpret_cast<sycl::half *>(x), incx);
604+
reinterpret_cast<sycl::half *>(x), incx, {});
605605
__FORCE_MKL_FLUSH__(status);
606606
return 0;
607607
}
@@ -610,7 +610,7 @@ extern "C" int onemklHnrm2(syclQueue_t device_queue, int64_t n, const short *x,
610610
int64_t incx, short *result) {
611611
auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n,
612612
reinterpret_cast<const sycl::half *>(x), incx,
613-
reinterpret_cast<sycl::half *>(result));
613+
reinterpret_cast<sycl::half *>(result), {});
614614
__FORCE_MKL_FLUSH__(status);
615615
return 0;
616616
}

0 commit comments

Comments
 (0)