Skip to content

Commit 457e020

Browse files
authored
Update oneAPI.jl for the release 2024.2.0 (#446)
1 parent 2c0299f commit 457e020

File tree

7 files changed

+4055
-3429
lines changed

7 files changed

+4055
-3429
lines changed

deps/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
1212
oneAPI_Support_Headers_jll = "24f86df5-245d-5634-a4cc-32433d9800b3"
1313

1414
[compat]
15-
oneAPI_Support_Headers_jll = "=2024.1.0"
15+
oneAPI_Support_Headers_jll = "=2024.2.0"

deps/build_local.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ if !isfile(joinpath(conda_dir, "condarc-julia.yml"))
4040
mkpath(joinpath(conda_dir, "conda-meta"))
4141
touch(joinpath(conda_dir, "conda-meta", "history"))
4242
end
43-
Conda.add(["dpcpp_linux-64=2024.1.0", "mkl-devel-dpcpp=2024.1.0"], conda_dir;
43+
Conda.add(["dpcpp_linux-64=2024.2.0", "mkl-devel-dpcpp=2024.2.0"], conda_dir;
4444
channel="intel")
4545

4646
Conda.list(conda_dir)

deps/generate_interfaces.jl

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ using oneAPI_Support_Headers_jll
33
include("generate_helpers.jl")
44

55
blas = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "blas", "buffer_decls.hpp")]
6-
lapack = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "lapack.hpp")]
6+
lapack = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "lapack.hpp"),
7+
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "lapack", "scratchpad.hpp")]
78
sparse = [joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_structures.hpp"),
89
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_auxiliary.hpp"),
910
joinpath(oneAPI_Support_Headers_jll.artifact_dir, "include", "oneapi", "mkl", "spblas", "sparse_operations.hpp")]
@@ -64,13 +65,15 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
6465
occursin("get_matmat_data", header) && continue # SPARSE routine
6566
occursin("matmat(", header) && continue # SPARSE routine
6667
occursin("gemm_bias", header) && continue # BLAS routine
67-
occursin("heevx", header) && continue # LAPACK routine (compiler bug)
68-
occursin("hegvx", header) && continue # LAPACK routine (compiler bug)
6968
occursin("getri_batch", header) && occursin("ldainv", header) && continue # LAPACK routine
7069

7170
# Check if the routine is a template
7271
template = occursin("template", header)
7372
if template
73+
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_floating_point<fp> = nullptr> " => "")
74+
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_real_floating_point<fp> = nullptr> " => "")
75+
header = replace(header, "template <typename fp, oneapi::mkl::lapack::internal::is_complex_floating_point<fp> = nullptr> " => "")
76+
7477
header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_floating_point<data_t> = nullptr>" => "")
7578
header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_real_floating_point<data_t> = nullptr>" => "")
7679
header = replace(header, "template <typename data_t, oneapi::mkl::lapack::internal::is_complex_floating_point<data_t> = nullptr>" => "")
@@ -99,6 +102,7 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
99102

100103
# Replace the types
101104
header = replace(header, "sycl::queue &queue" => "syclQueue_t device_queue")
105+
header = replace(header, "sycl::queue& queue" => "syclQueue_t device_queue")
102106

103107
if library ("blas", "sparse")
104108
header = replace(header, "compute_mode mode = MKL_BLAS_COMPUTE_MODE" => "")
@@ -214,16 +218,21 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
214218
copy_header = header
215219
copy_header = replace(copy_header, "typename fp_type::value_type" => version_types_header[blas_version])
216220
copy_header = replace(copy_header, "fp_type" => version_types_header[blas_version])
221+
copy_header = replace(copy_header, "fp" => version_types_header[blas_version])
217222
copy_header = replace(copy_header, name_routine => "onemkl$(blas_version)$(name_routine)")
223+
if name_routine ("heevx_scratchpad_size", "hegvx_scratchpad_size")
224+
copy_header = replace(copy_header, "typename float _Complex::value_type" => "float")
225+
copy_header = replace(copy_header, "typename double _Complex::value_type" => "double")
226+
end
218227
if occursin("batch", name_routine) && !occursin("*", header)
219228
copy_header = replace(copy_header, "_batch" => "_batch_strided")
220229
end
221230
push!(signatures, (copy_header, name_routine, blas_version, type_routine, template))
222231
end
223232
else
224233
if isempty(list_versions)
225-
suffix = ""
226234
# The routine "optimize_trsm" has two versions.
235+
suffix = ""
227236
(name_routine == "optimize_trsm") && occursin("columns", header) && (suffix = "_advanced")
228237
name_routine ("set_csr_data", "set_coo_data") && occursin("int64_t", header) && (suffix = "_64")
229238
occursin("batch", name_routine) && !occursin("**", header) && (suffix = "_strided")
@@ -281,6 +290,13 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
281290
copy_header = replace(copy_header, "_batch" => "_batch_strided")
282291
end
283292
if library == "blas"
293+
# Out-of-place variants of trsm and trmm
294+
if occursin("trsm", header) && occursin("ldc", header)
295+
copy_header = replace(copy_header, "trsm" => "trsm_variant")
296+
end
297+
if occursin("trmm", header) && occursin("ldc", header)
298+
copy_header = replace(copy_header, "trmm" => "trmm_variant")
299+
end
284300
copy_header = replace(copy_header, "compute_mode mode," => "")
285301
copy_header = replace(copy_header, ", compute_mode mode)" => ")")
286302
copy_header = replace(copy_header, "value_or_pointer<float _Complex>" => "float _Complex")
@@ -380,11 +396,14 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
380396
parameters = replace(parameters, ", double " => ", ")
381397
parameters = replace(parameters, ", **" => ", ")
382398
parameters = replace(parameters, ", *" => ", ")
383-
384399
parameters = replace(parameters, "onemklTranspose *trans," => "convert(trans, group_count),")
400+
parameters = replace(parameters, "onemklTranspose* trans," => "convert(trans, group_count),")
385401
parameters = replace(parameters, "onemklUplo *uplo," => "convert(uplo, group_count),")
402+
parameters = replace(parameters, "onemklUplo* uplo," => "convert(uplo, group_count),")
386403
parameters = replace(parameters, "onemklDiag *diag," => "convert(diag, group_count),")
404+
parameters = replace(parameters, "onemklDiag* diag," => "convert(diag, group_count),")
387405
parameters = replace(parameters, "onemklSide *side," => "convert(side, group_count),")
406+
parameters = replace(parameters, "onemklSide* side," => "convert(side, group_count),")
388407

389408
for type in ("onemklTranspose", "onemklSide", "onemklUplo", "onemklDiag", "onemklGenerate",
390409
"onemklLayout", "onemklJob", "onemklJobsvd", "onemklCompz", "onemklRangev",

0 commit comments

Comments
 (0)