Skip to content

Commit 1748b81

Browse files
authored
[FRONTEND] Fix floating-point argument passing (#7439)
Fixes #6176.

```python
@triton.jit
def kernel(ptr, val: tl.float16):
    tl.store(ptr, val)

ptr = torch.tensor([0.0], device="cuda:0")
kernel[1,](ptr, 42.0)
print(ptr)
# Expected: tensor([42.], device='cuda:0')
# Actual: tensor([0.], device='cuda:0')
```

The issue is caused by naively passing a Python float to a Triton kernel that accepts `tl.float16`. Before this PR, the conversion chain for the float input looked like this:

```
         PyArg_ParseTuple          incorrectly passed to
PyFloat ================> float ----------!!!----------> kernel that accepts tl.float16
```

This PR makes `PyArg_ParseTuple` always parse a Python float into a C double, and then calls [`PyFloat_Pack{2,4,8}`](https://docs.python.org/3/c-api/float.html#pack-functions) to convert it to its proper storage type:

```
         PyArg_ParseTuple           PyFloat_Pack{2,4,8}                       passed to
PyFloat ==================> double ====================> uint{16,32,64}_t -------------> kernel that accepts tl.float{16,32,64}
```

The generated code snippet looks something like this (for the AMD backend):

```c
double _arg1;
PyArg_ParseTuple(args, "piiiKKOOOOOd", ..., &_arg1);
uint16_t _arg1_storage = 0;
PyFloat_Pack2(_arg1, (void*)&_arg1_storage, 1);
_launch(gridX, gridY, gridZ, ..., _arg1_storage);
```

- [x] Fix AMD backend
- [x] Fix NVIDIA backend
- [x] Add tests
1 parent 34fb64a commit 1748b81

File tree

3 files changed

+161
-18
lines changed

3 files changed

+161
-18
lines changed

python/test/unit/language/test_annotations.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import triton
44
import triton.language as tl
55
import pytest
6+
import numpy as np
67

78

89
def annotated_function(return_type=None, **arg_types):
@@ -49,3 +50,36 @@ def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr):
4950
_kernel[(1, )](x.shape[0], x.shape[0], 32)
5051
except AttributeError:
5152
pass
53+
54+
55+
# Test float annotations are properly respected
@pytest.mark.parametrize(
    ("dtype", "test_val"),
    [(dtype, test_val)
     for dtype in [tl.float16, tl.bfloat16, tl.float32, tl.float64]
     for test_val in [0.0, 42.0, float("inf"), float("nan")]],
)
def test_float_annotation(device, dtype, test_val):
    """Check that a Python float reaches a kernel argument annotated as ``dtype``.

    The kernel stores the scalar argument into a float32 buffer; the test then
    compares the stored value with ``test_val`` and inspects the generated
    TTIR to confirm the argument was emitted with the annotated type.
    """

    @triton.jit
    @annotated_function(val=dtype)
    def _kernel(ptr, val):
        tl.static_assert(val.dtype == dtype)
        tl.store(ptr, val)

    # Pre-fill the output with a sentinel that differs from every test value.
    # With an uninitialised buffer, leftover memory could already equal
    # `test_val` (most likely for 0.0) and hide a missing or wrong store.
    ptr = torch.full((1, ), -1.0, device=device, dtype=torch.float32)
    h = _kernel[(1, )](ptr, test_val)
    # assert_allclose treats NaN == NaN and inf == inf as matches by default,
    # so the non-finite test values are compared meaningfully here.
    np.testing.assert_allclose(ptr.cpu().numpy(), [test_val], atol=1e-6)

    # Check that the type is properly emitted in the IR
    if dtype == tl.float16:
        assert "%arg1: f16" in h.asm["ttir"]
        assert "arith.extf %arg1 : f16 to f32" in h.asm["ttir"]
    elif dtype == tl.bfloat16:
        assert "%arg1: bf16" in h.asm["ttir"]
        assert "arith.extf %arg1 : bf16 to f32" in h.asm["ttir"]
    elif dtype == tl.float32:
        # Same width as the destination buffer, so no cast is expected.
        assert "%arg1: f32" in h.asm["ttir"]
    elif dtype == tl.float64:
        assert "%arg1: f64" in h.asm["ttir"]
        assert "arith.truncf %arg1 : f64 to f32" in h.asm["ttir"]

third_party/amd/backend/driver.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,29 @@ def ty_to_cpp(ty):
163163
"u16": "uint16_t",
164164
"u32": "uint32_t",
165165
"u64": "uint64_t",
166-
"fp16": "float",
167-
"bf16": "float",
168-
"fp32": "float",
169-
"f32": "float",
166+
"fp16": "double",
167+
"bf16": "double",
168+
"fp32": "double",
169+
"f32": "double",
170170
"fp64": "double",
171171
}[ty]
172172

173173

174+
# Unsigned integer C type whose width matches each floating-point signature
# type; the launcher passes the raw bits of a float argument in this type.
FLOAT_STORAGE_TYPE = dict(
    fp16="uint16_t",
    bf16="uint16_t",
    fp32="uint32_t",
    f32="uint32_t",
    fp64="uint64_t",
)
# Name of the C helper that converts a parsed C double into those raw bits.
FLOAT_PACK_FUNCTION = dict(
    fp16="pack_fp16",
    bf16="pack_bf16",
    fp32="pack_fp32",
    f32="pack_fp32",
    fp64="pack_fp64",
)
188+
174189
_BASE_ARGS_FORMAT = "piiiKKOOOO"
175190

176191

@@ -226,7 +241,6 @@ def format_of(ty):
226241
if ty == "constexpr":
227242
return "O"
228243
return {
229-
"float": "f",
230244
"double": "d",
231245
"long": "l",
232246
"int8_t": "b",
@@ -249,13 +263,30 @@ def format_of(ty):
249263
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
250264
# Record the end of regular arguments;
251265
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
252-
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
266+
arg_decl_list = []
267+
for i, ty in signature.items():
268+
if ty == "constexpr":
269+
continue
270+
if ty in FLOAT_STORAGE_TYPE:
271+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
272+
else:
273+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
274+
arg_decls = ', '.join(arg_decl_list)
253275
internal_args_list = []
254276
for i, ty in signature.items():
255277
if ty[0] == "*":
256278
internal_args_list.append(f"ptr_info{i}.dev_ptr")
279+
elif ty in FLOAT_STORAGE_TYPE:
280+
internal_args_list.append(f"_arg{i}_storage")
257281
elif ty != "constexpr":
258282
internal_args_list.append(f"_arg{i}")
283+
284+
float_storage_decls = [
285+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
286+
for i, ty in signature.items()
287+
if ty in FLOAT_STORAGE_TYPE
288+
]
289+
259290
libhip_path = _get_path_to_hip_runtime_dylib()
260291

261292
# generate glue code
@@ -309,9 +340,6 @@ def format_of(ty):
309340
bool initSymbolTable() {{
310341
// Use the HIP runtime library loaded into the existing process if it exits.
311342
void *lib = dlopen("libamdhip64.so", RTLD_NOLOAD);
312-
if (lib) {{
313-
// printf("[triton] chosen loaded libamdhip64.so in the process\\n");
314-
}}
315343
316344
// Otherwise, go through the list of search paths to dlopen the first HIP
317345
// driver library.
@@ -321,7 +349,6 @@ def format_of(ty):
321349
void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
322350
if (handle) {{
323351
lib = handle;
324-
// printf("[triton] chosen %s\\n", hipLibSearchPaths[i]);
325352
}}
326353
}}
327354
}}
@@ -382,7 +409,6 @@ def format_of(ty):
382409
#define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
383410
384411
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, hipStream_t stream, hipFunction_t function{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
385-
// printf("_launch hip kernel\\n");
386412
hipDeviceptr_t global_scratch = 0;
387413
void *params[] = {{ {', '.join(params)} }};
388414
if (gridX*gridY*gridZ > 0 && launch_cooperative_grid) {{
@@ -440,8 +466,33 @@ def format_of(ty):
440466
return ptr_info;
441467
}}
442468
469+
// Convert a C double to IEEE 754 binary16 and return its raw bit pattern.
// Finite values too large for binary16 saturate to a signed infinity.
// Must be called with the GIL held because it uses the Python C API.
static uint16_t pack_fp16(double f) {{
  // Pack into a byte buffer in little-endian order and assemble the integer
  // by hand, so the result does not depend on the host byte order.
  unsigned char bytes[2];
  int status;
  // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
  status = _PyFloat_Pack2(f, (void*)bytes, 1);
#else
  status = PyFloat_Pack2(f, (void*)bytes, 1);
#endif
  if (status < 0) {{
    // The pack call reports failure by returning -1 with an exception set
    // (in CPython: OverflowError for a finite value beyond the binary16
    // range). This helper has no error channel, so drop the exception rather
    // than leave it pending and return the value an IEEE round-to-nearest
    // conversion would produce: infinity with the sign of the input.
    PyErr_Clear();
    return (uint16_t)(f < 0 ? 0xFC00u : 0x7C00u);
  }}
  return (uint16_t)(bytes[0] | ((uint16_t)bytes[1] << 8));
}}
479+
480+
// Convert a C double to bfloat16 (the upper 16 bits of an IEEE 754 binary32)
// and return its raw bit pattern, rounding to nearest with ties to even.
static uint16_t pack_bf16(double f) {{
  float f32 = (float)f;
  uint32_t u32;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint32_t pointer cast from a float pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&u32, &f32, sizeof(u32));
  if ((u32 & 0x7FFFFFFFu) > 0x7F800000u) {{
    // NaN: keep the sign and upper payload bits and force a mantissa bit on,
    // so a payload living only in the discarded low bits cannot turn into inf.
    return (uint16_t)((u32 >> 16) | 0x0040u);
  }}
  // Round to nearest even: add half of the discarded range, plus one more
  // when the kept least-significant bit is odd, then drop the low 16 bits.
  // Plain truncation would bias every inexact value toward zero.
  u32 += 0x7FFFu + ((u32 >> 16) & 1u);
  return (uint16_t)(u32 >> 16);
}}
485+
486+
// Convert a C double to IEEE 754 binary32 and return its raw bit pattern.
static uint32_t pack_fp32(double f) {{
  float f32 = (float)f;
  uint32_t bits;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint32_t pointer cast from a float pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&bits, &f32, sizeof(bits));
  return bits;
}}
490+
491+
// Return the raw IEEE 754 binary64 bit pattern of a C double.
static uint64_t pack_fp64(double f) {{
  uint64_t bits;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint64_t pointer cast from a double pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&bits, &f, sizeof(bits));
  return bits;
}}
494+
443495
static PyObject* launch(PyObject* self, PyObject* args) {{
444-
// printf("launch\\n");
445496
int gridX, gridY, gridZ;
446497
uint64_t _stream;
447498
uint64_t _function;
@@ -458,6 +509,8 @@ def format_of(ty):
458509
return NULL;
459510
}}
460511
512+
{' '.join(float_storage_decls)}
513+
461514
// extract kernel metadata
462515
int num_warps, num_ctas, shared_memory, clusterDimX, clusterDimY, clusterDimZ;
463516
if (!PyArg_ParseTuple(kernel_metadata, \"iiiiii\", &num_warps, &num_ctas, &shared_memory, &clusterDimX, &clusterDimY, &clusterDimZ)) {{

third_party/nvidia/backend/driver.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,30 @@ def ty_to_cpp(ty):
9494
"u16": "uint16_t",
9595
"u32": "uint32_t",
9696
"u64": "uint64_t",
97-
"fp16": "float",
98-
"bf16": "float",
99-
"fp32": "float",
100-
"f32": "float",
97+
"fp16": "double",
98+
"bf16": "double",
99+
"fp32": "double",
100+
"f32": "double",
101101
"fp64": "double",
102102
"nvTmaDesc": "CUtensorMap",
103103
}[ty]
104104

105105

106+
# Unsigned integer C type whose width matches each floating-point signature
# type; the launcher passes the raw bits of a float argument in this type.
FLOAT_STORAGE_TYPE = dict(
    fp16="uint16_t",
    bf16="uint16_t",
    fp32="uint32_t",
    f32="uint32_t",
    fp64="uint64_t",
)
# Name of the C helper that converts a parsed C double into those raw bits.
FLOAT_PACK_FUNCTION = dict(
    fp16="pack_fp16",
    bf16="pack_bf16",
    fp32="pack_fp32",
    f32="pack_fp32",
    fp64="pack_fp64",
)
120+
106121
_BASE_ARGS_FORMAT = "iiiKKppOOOOO"
107122

108123

@@ -175,7 +190,6 @@ def format_of(ty):
175190
if ty.startswith("tensordesc"):
176191
return "O"
177192
return {
178-
"float": "f",
179193
"double": "d",
180194
"long": "l",
181195
"int8_t": "b",
@@ -201,11 +215,21 @@ def format_of(ty):
201215
args_list = ', ' + ', '.join(f"&_arg{i}" for i, ty in signature.items()) if len(signature) > 0 else ''
202216
# Record the end of regular arguments;
203217
# subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA.
204-
arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items() if ty != "constexpr")
218+
arg_decl_list = []
219+
for i, ty in signature.items():
220+
if ty == "constexpr":
221+
continue
222+
if ty in FLOAT_STORAGE_TYPE:
223+
arg_decl_list.append(f"{FLOAT_STORAGE_TYPE[ty]} arg{i}")
224+
else:
225+
arg_decl_list.append(f"{ty_to_cpp(ty)} arg{i}")
226+
arg_decls = ', '.join(arg_decl_list)
205227
internal_args_list = []
206228
for i, ty in signature.items():
207229
if ty[0] == "*":
208230
internal_args_list.append(f"ptr_info{i}.dev_ptr")
231+
elif ty in FLOAT_STORAGE_TYPE:
232+
internal_args_list.append(f"_arg{i}_storage")
209233
elif ty == "nvTmaDesc":
210234
# Note: we have to dereference the pointer
211235
internal_args_list.append(f"*tma_ptr{i}")
@@ -224,6 +248,11 @@ def format_of(ty):
224248
f"CUtensorMap* tma_ptr{i} = getTmaDesc(_arg{i}); if (!tma_ptr{i}) return NULL;" for i, ty in signature.items()
225249
if ty == "nvTmaDesc"
226250
]
251+
float_storage_decls = [
252+
f"{FLOAT_STORAGE_TYPE[ty]} _arg{i}_storage = {FLOAT_PACK_FUNCTION[ty]}(_arg{i});"
253+
for i, ty in signature.items()
254+
if ty in FLOAT_STORAGE_TYPE
255+
]
227256
params = [f"&arg{i}" for i, ty in signature.items() if ty != "constexpr"]
228257
params.append("&global_scratch")
229258
src = f"""
@@ -442,6 +471,32 @@ def format_of(ty):
442471
}}
443472
}}
444473
474+
// Convert a C double to IEEE 754 binary16 and return its raw bit pattern.
// Finite values too large for binary16 saturate to a signed infinity.
// Must be called with the GIL held because it uses the Python C API.
static uint16_t pack_fp16(double f) {{
  // Pack into a byte buffer in little-endian order and assemble the integer
  // by hand, so the result does not depend on the host byte order.
  unsigned char bytes[2];
  int status;
  // from https://github.com/python/pythoncapi-compat
#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION)
  status = _PyFloat_Pack2(f, (void*)bytes, 1);
#else
  status = PyFloat_Pack2(f, (void*)bytes, 1);
#endif
  if (status < 0) {{
    // The pack call reports failure by returning -1 with an exception set
    // (in CPython: OverflowError for a finite value beyond the binary16
    // range). This helper has no error channel, so drop the exception rather
    // than leave it pending and return the value an IEEE round-to-nearest
    // conversion would produce: infinity with the sign of the input.
    PyErr_Clear();
    return (uint16_t)(f < 0 ? 0xFC00u : 0x7C00u);
  }}
  return (uint16_t)(bytes[0] | ((uint16_t)bytes[1] << 8));
}}
484+
485+
// Convert a C double to bfloat16 (the upper 16 bits of an IEEE 754 binary32)
// and return its raw bit pattern, rounding to nearest with ties to even.
static uint16_t pack_bf16(double f) {{
  float f32 = (float)f;
  uint32_t u32;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint32_t pointer cast from a float pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&u32, &f32, sizeof(u32));
  if ((u32 & 0x7FFFFFFFu) > 0x7F800000u) {{
    // NaN: keep the sign and upper payload bits and force a mantissa bit on,
    // so a payload living only in the discarded low bits cannot turn into inf.
    return (uint16_t)((u32 >> 16) | 0x0040u);
  }}
  // Round to nearest even: add half of the discarded range, plus one more
  // when the kept least-significant bit is odd, then drop the low 16 bits.
  // Plain truncation would bias every inexact value toward zero.
  u32 += 0x7FFFu + ((u32 >> 16) & 1u);
  return (uint16_t)(u32 >> 16);
}}
490+
491+
// Convert a C double to IEEE 754 binary32 and return its raw bit pattern.
static uint32_t pack_fp32(double f) {{
  float f32 = (float)f;
  uint32_t bits;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint32_t pointer cast from a float pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&bits, &f32, sizeof(bits));
  return bits;
}}
495+
496+
// Return the raw IEEE 754 binary64 bit pattern of a C double.
static uint64_t pack_fp64(double f) {{
  uint64_t bits;
  // memcpy is the well-defined way to reinterpret the bits; dereferencing a
  // uint64_t pointer cast from a double pointer breaks C aliasing rules.
  // NOTE: memcpy comes from string.h, which Python.h is documented to include.
  memcpy(&bits, &f, sizeof(bits));
  return bits;
}}
499+
445500
static PyObject* launch(PyObject* self, PyObject* args) {{
446501
// ensure cuda context is valid before calling any CUDA APIs, e.g. before getPointer calls cuPointerGetAttributes
447502
ensureCudaContext();
@@ -492,6 +547,7 @@ def format_of(ty):
492547
// raise exception asap
493548
{newline.join(ptr_decls)}
494549
{newline.join(tma_decls)}
550+
{newline.join(float_storage_decls)}
495551
Py_BEGIN_ALLOW_THREADS;
496552
_launch(gridX, gridY, gridZ, num_warps, num_ctas, launch_cooperative_grid, launch_pdl, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function, global_scratch{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
497553
Py_END_ALLOW_THREADS;

0 commit comments

Comments
 (0)