@@ -9,16 +9,12 @@ dict_version = Dict{Int, Char}(1 => 'S', 2 => 'D', 3 => 'C', 4 => 'Z')
9
9
version_types = Dict {Char, String} (' S' => " float" ,
10
10
' D' => " double" ,
11
11
' C' => " std::complex<float>" ,
12
- ' Z' => " std::complex<double>" ,
13
- ' I' => " int32_t" ,
14
- ' L' => " int64_t" )
12
+ ' Z' => " std::complex<double>" )
15
13
16
14
version_types_header = Dict {Char, String} (' S' => " float" ,
17
15
' D' => " double" ,
18
16
' C' => " float _Complex" ,
19
- ' Z' => " double _Complex" ,
20
- ' I' => " int32_t" ,
21
- ' L' => " int64_t" )
17
+ ' Z' => " double _Complex" )
22
18
23
19
function generate_headers (library:: String , filename:: String , output:: String )
24
20
routines = Dict {String,Int} ()
@@ -49,8 +45,10 @@ function generate_headers(library::String, filename::String, output::String)
49
45
occursin (" heevx" , header) && continue # LAPACK routine
50
46
occursin (" hegvx" , header) && continue # LAPACK routine
51
47
occursin (" (matrix_handle_t handle" , header) && continue # SPARSE routine
48
+ occursin (" update_diagonal_values" , header) && continue # SPARSE routine
49
+ occursin (" get_matmat_data" , header) && continue # SPARSE routine
50
+ occursin (" matmat(" , header) && continue # SPARSE routine
52
51
occursin (" gemvdot" , header) && continue # SPARSE routine
53
- occursin (" matmat" , header) && continue # SPARSE routine
54
52
55
53
# Check if the routine is a template
56
54
template = occursin (" template" , header)
@@ -104,8 +102,10 @@ function generate_headers(library::String, filename::String, output::String)
104
102
header = replace (header, " sycl::buffer<double, 1> &" => " double *" )
105
103
header = replace (header, " sycl::buffer<std::complex<float>, 1> &" => " float _Complex *" )
106
104
header = replace (header, " sycl::buffer<std::complex<double>, 1> &" => " double _Complex *" )
105
+ header = replace (header, " sycl::buffer<std::uint8_t, 1> *" => " uint8_t *" )
107
106
header = replace (header, " sycl::buffer<int32_t, 1> &" => " int32_t *" )
108
107
header = replace (header, " sycl::buffer<int64_t, 1> &" => " int64_t *" )
108
+ header = replace (header, " sycl::buffer<int64_t, 1> *" => " int64_t *" )
109
109
110
110
header = replace (header, " template <>\n " => " " )
111
111
header = replace (header, " <std::complex<float>>" => " " )
@@ -129,6 +129,7 @@ function generate_headers(library::String, filename::String, output::String)
129
129
header = replace (header, " oneapi::mkl::layout" => " onemklLayout" )
130
130
header = replace (header, " oneapi::mkl::index" => " onemklIndex" )
131
131
header = replace (header, " oneapi::mkl::property" => " onemklProperty" )
132
+ header = replace (header, " sparse::matmat_descr_t" => " matmat_descr_t" )
132
133
133
134
# Sanitize the header
134
135
header = replace (header, " \\ " => " " )
@@ -203,10 +204,6 @@ function generate_headers(library::String, filename::String, output::String)
203
204
end
204
205
end
205
206
version = ' X'
206
- if library == " sparse"
207
- version = occursin (" int32_t" , header) ? ' I' : version
208
- version = occursin (" int64_t" , header) ? ' L' : version
209
- end
210
207
version = occursin (" double" , header) ? ' D' : version
211
208
version = occursin (" float" , header) ? ' S' : version
212
209
version = occursin (" float _Complex" , header) ? ' C' : version
@@ -218,6 +215,7 @@ function generate_headers(library::String, filename::String, output::String)
218
215
versions = (' S' , ' D' , ' C' , ' Z' )
219
216
mapreduce (x -> startswith (name_routine, x), | , [" or" , " sy" ]) && ! startswith (name_routine, " sytrf" ) && (versions = (' S' , ' D' ))
220
217
mapreduce (x -> startswith (name_routine, x), | , [" un" , " he" ]) && (versions = (' C' , ' Z' ))
218
+ (name_routine == " gesvd_scratchpad_size" ) && (routines[name_routine] > 1 ) && continue
221
219
routines[name_routine] = routines[name_routine] - 1 + length (versions)
222
220
for blas_version in versions
223
221
copy_header = header
@@ -229,12 +227,14 @@ function generate_headers(library::String, filename::String, output::String)
229
227
end
230
228
else
231
229
if isempty (list_versions)
230
+ suffix = " "
232
231
if name_routine == " set_csr_data"
233
- occursin (" int32_t" , header) && (version = " I" * version)
234
- occursin (" int64_t" , header) && (version = " L" * version)
232
+ occursin (" int64_t" , header) && (suffix = " _64" )
233
+ end
234
+ header = replace (header, " $(name_routine) (" => " onemkl$(version)$(name_routine)$(suffix) (" )
235
+ if name_routine ∉ (" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" )
236
+ header = replace (header, " void onemkl" => " int onemkl" )
235
237
end
236
- header = replace (header, name_routine => " onemkl$(version)$(name_routine) " )
237
- header = replace (header, " void onemkl" => " int onemkl" )
238
238
if library == " sparse"
239
239
if occursin (" std::complex" , header)
240
240
(version == ' C' ) && (header = replace (header, " std::complex " => " float _Complex " ))
@@ -247,6 +247,9 @@ function generate_headers(library::String, filename::String, output::String)
247
247
header = replace (header, " layout " => " onemklLayout " )
248
248
header = replace (header, " index_base " => " onemklIndex " )
249
249
header = replace (header, " property " => " onemklProperty " )
250
+ header = replace (header, " sparse::matrix_view_descr " => " onemklMatrixView " )
251
+ header = replace (header, " matrix_view_descr " => " onemklMatrixView " )
252
+ header = replace (header, " sparse::matmat_request " => " onemklMatmatRequest " )
250
253
header = replace (header, name_routine => " sparse_" * name_routine)
251
254
end
252
255
push! (signatures, (header, name_routine, version, template))
@@ -280,7 +283,7 @@ function generate_headers(library::String, filename::String, output::String)
280
283
# Check the number of methods
281
284
blacklist = String[]
282
285
for name_routine in keys (routines)
283
- if (routines[name_routine] > 5 ) && (library != " sparse " )
286
+ if (routines[name_routine] > 5 ) && (name_routine != " set_csr_data " )
284
287
@warn " The routine $(name_routine) has more than 4 methods and will not be interfaced."
285
288
push! (blacklist, name_routine)
286
289
end
@@ -331,7 +334,10 @@ function generate_cpp(library::String, filename::String, output::String)
331
334
parameters = replace (parameters, " syclQueue_t device_queue" => " device_queue->val" )
332
335
parameters = replace (parameters, " int32_t " => " " )
333
336
parameters = replace (parameters, " int64_t " => " " )
334
- parameters = replace (parameters, " matrix_handle_t " => " " )
337
+ parameters = replace (parameters, " matrix_handle_t *" => " (oneapi::mkl::sparse::matrix_handle_t*) " )
338
+ parameters = replace (parameters, " matrix_handle_t " => " (oneapi::mkl::sparse::matrix_handle_t) " )
339
+ parameters = replace (parameters, " matmat_descr_t *" => " (oneapi::mkl::sparse::matmat_descr_t*) " )
340
+ parameters = replace (parameters, " matmat_descr_t " => " (oneapi::mkl::sparse::matmat_descr_t) " )
335
341
parameters = replace (parameters, " float _Complex *" => " reinterpret_cast<std::complex<float> *>" )
336
342
parameters = replace (parameters, " double _Complex *" => " reinterpret_cast<std::complex<double> *>" )
337
343
parameters = replace (parameters, " float _Complex " => " static_cast<std::complex<float> >" )
@@ -342,13 +348,15 @@ function generate_cpp(library::String, filename::String, output::String)
342
348
parameters = replace (parameters, " , double " => " , " )
343
349
parameters = replace (parameters, " , *" => " , " )
344
350
345
- for type in (" onemklTranspose" , " onemklSide" , " onemklUplo" , " onemklDiag" , " onemklGenerate" ,
346
- " onemklJob" , " onemklJobsvd" , " onemklCompz" , " onemklRangev" , " onemklIndex" , " onemklProperty" )
347
- parameters = replace (parameters, Regex (" $type ([a-z_]+)," ) => SubstitutionString (" convert(\\ 1)," ))
348
- parameters = replace (parameters, Regex (" , $type ([a-z_]+)" ) => SubstitutionString (" , convert(\\ 1)" ))
351
+ for type in (" onemklTranspose" , " onemklSide" , " onemklUplo" , " onemklDiag" ,
352
+ " onemklGenerate" , " onemklLayout" , " onemklJob" , " onemklJobsvd" ,
353
+ " onemklCompz" , " onemklRangev" , " onemklIndex" , " onemklProperty" ,
354
+ " onemklMatrixView" , " onemklMatmatRequest" )
355
+ parameters = replace (parameters, Regex (" $type ([A-Za-z_]+)," ) => SubstitutionString (" convert(\\ 1)," ))
356
+ parameters = replace (parameters, Regex (" , $type ([A-Za-z_]+)" ) => SubstitutionString (" , convert(\\ 1)" ))
349
357
end
350
- parameters = replace (parameters, r" >([a-z ]+)" => s " >(\1 )" )
351
- parameters = replace (parameters, r" \* >([a-z ]+)" => s " *>(\1 )" )
358
+ parameters = replace (parameters, r" >([A-Za-z_ ]+)" => s " >(\1 )" )
359
+ parameters = replace (parameters, r" \* >([A-Za-z_ ]+)" => s " *>(\1 )" )
352
360
353
361
variant = " "
354
362
if library == " blas"
@@ -361,13 +369,17 @@ function generate_cpp(library::String, filename::String, output::String)
361
369
! occursin (" scratchpad_size" , name) && write (oneapi_cpp, " auto status = oneapi::mkl::$library ::$variant$name <$type >($parameters );\n " )
362
370
occursin (" scratchpad_size" , name) && write (oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library ::$variant$name <$type >($parameters );\n " )
363
371
else
364
- write (oneapi_cpp, " auto status = oneapi::mkl::$library ::$variant$name ($parameters );\n " )
372
+ if name ∉ (" init_matrix_handle" , " init_matmat_descr" , " release_matmat_descr" , " set_matmat_data" )
373
+ write (oneapi_cpp, " auto status = oneapi::mkl::$library ::$variant$name ($parameters );\n " )
374
+ else
375
+ write (oneapi_cpp, " oneapi::mkl::$library ::$variant$name ($parameters );\n " )
376
+ end
365
377
end
366
378
if occursin (" scratchpad_size" , name)
367
379
write (oneapi_cpp, " return scratchpad_size;\n " )
368
380
else
369
- write (oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n " )
370
- write (oneapi_cpp, " return 0;\n " )
381
+ (name ∉ ( " init_matrix_handle " , " init_matmat_descr " , " release_matmat_descr " , " set_matmat_data " )) && write (oneapi_cpp, " __FORCE_MKL_FLUSH__(status);\n " )
382
+ (name ∉ ( " init_matrix_handle " , " init_matmat_descr " , " release_matmat_descr " , " set_matmat_data " )) && write (oneapi_cpp, " return 0;\n " )
371
383
end
372
384
write (oneapi_cpp, " }" )
373
385
write (oneapi_cpp, " \n\n " )
377
389
378
390
generate_headers (" lapack" , lapack, " onemkl_lapack.h" )
379
391
generate_headers (" blas" , blas, " onemkl_blas.h" )
380
- # generate_headers("sparse", sparse, "onemkl_sparse.h")
392
+ generate_headers (" sparse" , sparse, " onemkl_sparse.h" )
381
393
382
394
io = open (" src/onemkl.h" , " w" )
383
395
headers_prologue = read (" onemkl_prologue.h" , String)
@@ -388,16 +400,16 @@ write(io, headers_blas)
388
400
headers_lapack = read (" onemkl_lapack.h" , String)
389
401
write (io, " // LAPACK\n " )
390
402
write (io, headers_lapack)
391
- # headers_sparse = read("onemkl_sparse.h", String)
392
- # write(io, "// SPARSE\n")
393
- # write(io, headers_sparse)
403
+ headers_sparse = read (" onemkl_sparse.h" , String)
404
+ write (io, " // SPARSE\n " )
405
+ write (io, headers_sparse)
394
406
headers_epilogue = read (" onemkl_epilogue.h" , String)
395
407
write (io, headers_epilogue)
396
408
close (io)
397
409
398
410
generate_cpp (" lapack" , lapack, " onemkl_lapack.cpp" )
399
411
generate_cpp (" blas" , blas, " onemkl_blas.cpp" )
400
- # generate_cpp("sparse", sparse, "onemkl_sparse.cpp")
412
+ generate_cpp (" sparse" , sparse, " onemkl_sparse.cpp" )
401
413
402
414
io = open (" src/onemkl.cpp" , " w" )
403
415
cpp_prologue = read (" onemkl_prologue.cpp" , String)
@@ -408,9 +420,9 @@ write(io, cpp_blas)
408
420
cpp_lapack = read (" onemkl_lapack.cpp" , String)
409
421
write (io, " // LAPACK\n " )
410
422
write (io, cpp_lapack)
411
- # cpp_sparse = read("onemkl_sparse.cpp", String)
412
- # write(io, "// SPARSE\n")
413
- # write(io, cpp_sparse)
423
+ cpp_sparse = read (" onemkl_sparse.cpp" , String)
424
+ write (io, " // SPARSE\n " )
425
+ write (io, cpp_sparse)
414
426
cpp_epilogue = read (" onemkl_epilogue.cpp" , String)
415
427
write (io, cpp_epilogue)
416
428
close (io)
0 commit comments