@@ -343,14 +343,30 @@ def ty_to_cpp(ty):
343
343
"u16" : "uint16_t" ,
344
344
"u32" : "uint32_t" ,
345
345
"u64" : "uint64_t" ,
346
- "fp16" : "float " ,
347
- "bf16" : "float " ,
348
- "fp32" : "float " ,
349
- "f32" : "float " ,
346
+ "fp16" : "double " ,
347
+ "bf16" : "double " ,
348
+ "fp32" : "double " ,
349
+ "f32" : "double " ,
350
350
"fp64" : "double" ,
351
351
}[ty ]
352
352
353
353
354
# Unsigned integer C++ type that the generated launcher uses to declare each
# floating-point kernel argument (and its `_arg{i}_storage` temporary), keyed
# by the signature type name.
FLOAT_STORAGE_TYPE = {
    # 16-bit formats
    "fp16": "uint16_t",
    "bf16": "uint16_t",
    # 32-bit format (both spellings)
    "fp32": "uint32_t",
    "f32": "uint32_t",
    # 64-bit format
    "fp64": "uint64_t",
}
361
# Name of the C++ helper, emitted into the generated launcher source, that is
# called on `_arg{i}` to produce the storage value for each floating-point
# signature type.
FLOAT_PACK_FUNCTION = {
    "fp16": "pack_fp16",
    "bf16": "pack_bf16",
    # "fp32" and "f32" are two spellings that share one helper.
    "fp32": "pack_fp32",
    "f32": "pack_fp32",
    "fp64": "pack_fp64",
}
368
+
369
+
354
370
def make_launcher (constants , signature ):
355
371
356
372
def _serialize_signature (sig ):
@@ -379,7 +395,6 @@ def format_of(ty):
379
395
if ty == "void*" :
380
396
return "O"
381
397
return {
382
- "float" : "f" ,
383
398
"double" : "d" ,
384
399
"long" : "l" ,
385
400
"int8_t" : "b" ,
@@ -401,11 +416,21 @@ def format_of(ty):
401
416
402
417
# Record the end of regular arguments;
403
418
# subsequent arguments are architecture-specific descriptors.
404
- arg_decls = ', ' .join (f"{ ty_to_cpp (ty )} arg{ i } " for i , ty in signature .items () if ty != "constexpr" )
419
+ arg_decl_list = []
420
+ for i , ty in signature .items ():
421
+ if ty == "constexpr" :
422
+ continue
423
+ if ty in FLOAT_STORAGE_TYPE :
424
+ arg_decl_list .append (f"{ FLOAT_STORAGE_TYPE [ty ]} arg{ i } " )
425
+ else :
426
+ arg_decl_list .append (f"{ ty_to_cpp (ty )} arg{ i } " )
427
+ arg_decls = ', ' .join (arg_decl_list )
405
428
internal_args_list = []
406
429
for i , ty in signature .items ():
407
430
if ty [0 ] == "*" :
408
431
internal_args_list .append (f"ptr_info{ i } .dev_ptr" )
432
+ elif ty in FLOAT_STORAGE_TYPE :
433
+ internal_args_list .append (f"_arg{ i } _storage" )
409
434
elif ty != "constexpr" :
410
435
internal_args_list .append (f"_arg{ i } " )
411
436
@@ -416,6 +441,11 @@ def format_of(ty):
416
441
for i , ty in signature .items ()
417
442
if ty [0 ] == "*"
418
443
]
444
+ float_storage_decls = [
445
+ f"{ FLOAT_STORAGE_TYPE [ty ]} _arg{ i } _storage = { FLOAT_PACK_FUNCTION [ty ]} (_arg{ i } );"
446
+ for i , ty in signature .items ()
447
+ if ty in FLOAT_STORAGE_TYPE
448
+ ]
419
449
params = [f"&arg{ i } " for i , ty in signature .items () if ty != "constexpr" ]
420
450
params .append ("&global_scratch" )
421
451
num_params = len (params )
@@ -424,6 +454,7 @@ def format_of(ty):
424
454
params_decl = f"void *params[] = {{ { ', ' .join (params )} }};"
425
455
src = f"""
426
456
#include <cstddef>
457
+ #include <Python.h>
427
458
#include <string>
428
459
#include <iostream>
429
460
#include <iomanip>
@@ -564,6 +595,32 @@ def format_of(ty):
564
595
}}
565
596
// end sycl
566
597
598
+ static uint16_t pack_fp16(double f) {{
599
+ uint16_t result;
600
+ // from https://github.com/python/pythoncapi-compat
601
+ #if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
602
+ _PyFloat_Pack2(f, (unsigned char *)&result, 1);
603
+ #else
604
+ PyFloat_Pack2(f, (void*)&result, 1);
605
+ #endif
606
+ return result;
607
+ }}
608
+
609
// Narrow a double to float and return the upper 16 bits of the float's bit
// pattern (the low 16 bits are dropped without rounding).
static uint16_t pack_bf16(double f) {{
    float f32 = (float)f;
    uint32_t u32;
    // Copy the bytes instead of dereferencing a casted pointer: reading a
    // float object through a uint32_t lvalue violates strict aliasing.
    // NOTE(review): memcpy is expected to be declared via <Python.h> /
    // <string>; add <cstring> to the include list if that is not the case.
    memcpy(&u32, &f32, sizeof(u32));
    return (uint16_t)(u32 >> 16);
}}
614
+
615
// Narrow a double to float and return the float's bit pattern as a uint32_t.
static uint32_t pack_fp32(double f) {{
    float f32 = (float)f;
    uint32_t bits;
    // Copy the bytes instead of dereferencing a casted pointer: reading a
    // float object through a uint32_t lvalue violates strict aliasing.
    // NOTE(review): memcpy is expected to be declared via <Python.h> /
    // <string>; add <cstring> to the include list if that is not the case.
    memcpy(&bits, &f32, sizeof(bits));
    return bits;
}}
619
+
620
// Return the bit pattern of a double as a uint64_t.
static uint64_t pack_fp64(double f) {{
    uint64_t bits;
    // Copy the bytes instead of dereferencing a casted pointer: reading a
    // double object through a uint64_t lvalue violates strict aliasing.
    // NOTE(review): memcpy is expected to be declared via <Python.h> /
    // <string>; add <cstring> to the include list if that is not the case.
    memcpy(&bits, &f, sizeof(bits));
    return bits;
}}
623
+
567
624
extern "C" EXPORT_FUNC PyObject* launch(PyObject* args) {{
568
625
int gridX, gridY, gridZ;
569
626
void* global_scratch = nullptr;
@@ -625,6 +682,7 @@ def format_of(ty):
625
682
sycl::kernel kernel = *kernel_ptr;
626
683
627
684
{ newline .join (ptr_decls )}
685
+ { newline .join (float_storage_decls )}
628
686
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel, global_scratch{ ',' + ', ' .join (internal_args_list ) if len (internal_args_list ) > 0 else '' } );
629
687
if (PyErr_Occurred()) {{
630
688
return NULL;
0 commit comments