Skip to content

Commit 94c1410

Browse files
committed
[intel] fix 'test_float_annotation' from '1748b81' by updating driver.py
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3da2500 commit 94c1410

File tree

1 file changed

+64
-6
lines changed

1 file changed

+64
-6
lines changed

third_party/intel/backend/driver.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,30 @@ def ty_to_cpp(ty):
343343
"u16": "uint16_t",
344344
"u32": "uint32_t",
345345
"u64": "uint64_t",
346-
"fp16": "float",
347-
"bf16": "float",
348-
"fp32": "float",
349-
"f32": "float",
346+
"fp16": "double",
347+
"bf16": "double",
348+
"fp32": "double",
349+
"f32": "double",
350350
"fp64": "double",
351351
}[ty]
352352

353353

354+
FLOAT_STORAGE_TYPE = {
355+
"fp16": "uint16_t",
356+
"bf16": "uint16_t",
357+
"fp32": "uint32_t",
358+
"f32": "uint32_t",
359+
"fp64": "uint64_t",
360+
}
361+
FLOAT_PACK_FUNCTION = {
362+
"fp16": "pack_fp16",
363+
"bf16": "pack_bf16",
364+
"fp32": "pack_fp32",
365+
"f32": "pack_fp32",
366+
"fp64": "pack_fp64",
367+
}
368+
369+
354370
def make_launcher(constants, signature):
355371

356372
def _serialize_signature(sig):
@@ -379,7 +395,6 @@ def format_of(ty):
379395
if ty == "void*":
380396
return "O"
381397
return {
382-
"float": "f",
383398
"double": "d",
384399
"long": "l",
385400
"int8_t": "b",
@@ -401,11 +416,21 @@ def format_of(ty):
401416

402417
# Record the end of regular arguments;
403418
# subsequent arguments are architecture-specific descriptors.
404-
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
419+
arg_decl_list = []
420+
for i, ty in signature.items():
421+
if ty == "constexpr":
422+
continue
423+
if ty in FLOAT_STORAGE_TYPE:
424+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
425+
else:
426+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
427+
arg_decls = ', '.join(arg_decl_list)
405428
internal_args_list = []
406429
for i, ty in signature.items():
407430
if ty[0] == "*":
408431
internal_args_list.append(f"ptr_info{i}.dev_ptr")
432+
elif ty in FLOAT_STORAGE_TYPE:
433+
internal_args_list.append(f"_arg{i}_storage")
409434
elif ty != "constexpr":
410435
internal_args_list.append(f"_arg{i}")
411436

@@ -416,6 +441,11 @@ def format_of(ty):
416441
for i, ty in signature.items()
417442
if ty[0] == "*"
418443
]
444+
float_storage_decls = [
445+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
446+
for i, ty in signature.items()
447+
if ty in FLOAT_STORAGE_TYPE
448+
]
419449
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
420450
params.append("&global_scratch")
421451
num_params = len(params)
@@ -424,6 +454,7 @@ def format_of(ty):
424454
params_decl = f"void *params[] = {{ {', '.join(params)} }};"
425455
src = f"""
426456
#include <cstddef>
457+
#include <Python.h>
427458
#include <string>
428459
#include <iostream>
429460
#include <iomanip>
@@ -564,6 +595,32 @@ def format_of(ty):
564595
}}
565596
// end sycl
566597
598+
static uint16_t pack_fp16(double f) {{
599+
uint16_t result;
600+
// from https://github.com/python/pythoncapi-compat
601+
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
602+
_PyFloat_Pack2(f, (unsigned char *)&result, 1);
603+
#else
604+
PyFloat_Pack2(f, (void*)&result, 1);
605+
#endif
606+
return result;
607+
}}
608+
609+
static uint16_t pack_bf16(double f) {{
610+
float f32 = (float)f;
611+
uint32_t u32 = *(uint32_t*)&f32;
612+
return (uint16_t)(u32 >> 16);
613+
}}
614+
615+
static uint32_t pack_fp32(double f) {{
616+
float f32 = (float)f;
617+
return *(uint32_t*)&f32;
618+
}}
619+
620+
static uint64_t pack_fp64(double f) {{
621+
return *(uint64_t*)&f;
622+
}}
623+
567624
extern "C" EXPORT_FUNC PyObject* launch(PyObject* args) {{
568625
int gridX, gridY, gridZ;
569626
void* global_scratch = nullptr;
@@ -625,6 +682,7 @@ def format_of(ty):
625682
sycl::kernel kernel = *kernel_ptr;
626683
627684
{newline.join(ptr_decls)}
685+
{newline.join(float_storage_decls)}
628686
sycl_kernel_launch(gridX, gridY, gridZ, num_warps, threads_per_warp, shared_memory, stream, kernel, global_scratch{',' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
629687
if (PyErr_Occurred()) {{
630688
return NULL;

0 commit comments

Comments
 (0)