@@ -3,7 +3,8 @@ using oneAPI_Support_Headers_jll
3
3
include (" generate_helpers.jl" )
4
4
5
5
blas = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " blas" , " buffer_decls.hpp" )]
6
- lapack = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " lapack.hpp" )]
6
+ lapack = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " lapack.hpp" ),
7
+ joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " lapack" , " scratchpad.hpp" )]
7
8
sparse = [joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_structures.hpp" ),
8
9
joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_auxiliary.hpp" ),
9
10
joinpath (oneAPI_Support_Headers_jll. artifact_dir, " include" , " oneapi" , " mkl" , " spblas" , " sparse_operations.hpp" )]
@@ -64,13 +65,15 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
64
65
occursin (" get_matmat_data" , header) && continue # SPARSE routine
65
66
occursin (" matmat(" , header) && continue # SPARSE routine
66
67
occursin (" gemm_bias" , header) && continue # BLAS routine
67
- occursin (" heevx" , header) && continue # LAPACK routine (compiler bug)
68
- occursin (" hegvx" , header) && continue # LAPACK routine (compiler bug)
69
68
occursin (" getri_batch" , header) && occursin (" ldainv" , header) && continue # LAPACK routine
70
69
71
70
# Check if the routine is a template
72
71
template = occursin (" template" , header)
73
72
if template
73
+ header = replace (header, " template <typename fp, oneapi::mkl::lapack::internal::is_floating_point<fp> = nullptr> " => " " )
74
+ header = replace (header, " template <typename fp, oneapi::mkl::lapack::internal::is_real_floating_point<fp> = nullptr> " => " " )
75
+ header = replace (header, " template <typename fp, oneapi::mkl::lapack::internal::is_complex_floating_point<fp> = nullptr> " => " " )
76
+
74
77
header = replace (header, " template <typename data_t, oneapi::mkl::lapack::internal::is_floating_point<data_t> = nullptr>" => " " )
75
78
header = replace (header, " template <typename data_t, oneapi::mkl::lapack::internal::is_real_floating_point<data_t> = nullptr>" => " " )
76
79
header = replace (header, " template <typename data_t, oneapi::mkl::lapack::internal::is_complex_floating_point<data_t> = nullptr>" => " " )
@@ -99,6 +102,7 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
99
102
100
103
# Replace the types
101
104
header = replace (header, " sycl::queue &queue" => " syclQueue_t device_queue" )
105
+ header = replace (header, " sycl::queue& queue" => " syclQueue_t device_queue" )
102
106
103
107
if library ∈ (" blas" , " sparse" )
104
108
header = replace (header, " compute_mode mode = MKL_BLAS_COMPUTE_MODE" => " " )
@@ -214,16 +218,21 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
214
218
copy_header = header
215
219
copy_header = replace (copy_header, " typename fp_type::value_type" => version_types_header[blas_version])
216
220
copy_header = replace (copy_header, " fp_type" => version_types_header[blas_version])
221
+ copy_header = replace (copy_header, " fp" => version_types_header[blas_version])
217
222
copy_header = replace (copy_header, name_routine => " onemkl$(blas_version)$(name_routine) " )
223
+ if name_routine ∈ (" heevx_scratchpad_size" , " hegvx_scratchpad_size" )
224
+ copy_header = replace (copy_header, " typename float _Complex::value_type" => " float" )
225
+ copy_header = replace (copy_header, " typename double _Complex::value_type" => " double" )
226
+ end
218
227
if occursin (" batch" , name_routine) && ! occursin (" *" , header)
219
228
copy_header = replace (copy_header, " _batch" => " _batch_strided" )
220
229
end
221
230
push! (signatures, (copy_header, name_routine, blas_version, type_routine, template))
222
231
end
223
232
else
224
233
if isempty (list_versions)
225
- suffix = " "
226
234
# The routine "optimize_trsm" has two versions.
235
+ suffix = " "
227
236
(name_routine == " optimize_trsm" ) && occursin (" columns" , header) && (suffix = " _advanced" )
228
237
name_routine ∈ (" set_csr_data" , " set_coo_data" ) && occursin (" int64_t" , header) && (suffix = " _64" )
229
238
occursin (" batch" , name_routine) && ! occursin (" **" , header) && (suffix = " _strided" )
@@ -281,6 +290,13 @@ function generate_headers(library::String, filename::Vector{String}, output::Str
281
290
copy_header = replace (copy_header, " _batch" => " _batch_strided" )
282
291
end
283
292
if library == " blas"
293
+ # Out-of-place variants of trsm and trmm
294
+ if occursin (" trsm" , header) && occursin (" ldc" , header)
295
+ copy_header = replace (copy_header, " trsm" => " trsm_variant" )
296
+ end
297
+ if occursin (" trmm" , header) && occursin (" ldc" , header)
298
+ copy_header = replace (copy_header, " trmm" => " trmm_variant" )
299
+ end
284
300
copy_header = replace (copy_header, " compute_mode mode," => " " )
285
301
copy_header = replace (copy_header, " , compute_mode mode)" => " )" )
286
302
copy_header = replace (copy_header, " value_or_pointer<float _Complex>" => " float _Complex" )
@@ -380,11 +396,14 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
380
396
parameters = replace (parameters, " , double " => " , " )
381
397
parameters = replace (parameters, " , **" => " , " )
382
398
parameters = replace (parameters, " , *" => " , " )
383
-
384
399
parameters = replace (parameters, " onemklTranspose *trans," => " convert(trans, group_count)," )
400
+ parameters = replace (parameters, " onemklTranspose* trans," => " convert(trans, group_count)," )
385
401
parameters = replace (parameters, " onemklUplo *uplo," => " convert(uplo, group_count)," )
402
+ parameters = replace (parameters, " onemklUplo* uplo," => " convert(uplo, group_count)," )
386
403
parameters = replace (parameters, " onemklDiag *diag," => " convert(diag, group_count)," )
404
+ parameters = replace (parameters, " onemklDiag* diag," => " convert(diag, group_count)," )
387
405
parameters = replace (parameters, " onemklSide *side," => " convert(side, group_count)," )
406
+ parameters = replace (parameters, " onemklSide* side," => " convert(side, group_count)," )
388
407
389
408
for type in (" onemklTranspose" , " onemklSide" , " onemklUplo" , " onemklDiag" , " onemklGenerate" ,
390
409
" onemklLayout" , " onemklJob" , " onemklJobsvd" , " onemklCompz" , " onemklRangev" ,
0 commit comments