Fix #6176
```python
import torch
import triton
import triton.language as tl

@triton.jit
def kernel(ptr, val: tl.float16):
    tl.store(ptr, val)

ptr = torch.tensor([0.0], device="cuda:0")
kernel[(1,)](ptr, 42.0)
print(ptr)
# Expected: tensor([42.], device='cuda:0')
# Actual: tensor([0.], device='cuda:0')
```
The issue is caused by naively passing a Python float to a Triton kernel
that accepts `tl.float16`.
Before this PR, the conversion chain for the float argument looked like this:
```
        PyArg_ParseTuple         incorrectly passed to
PyFloat ================> float ----------!!!----------> kernel that accepts tl.float16
```
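One way to see where the observed `0.0` can come from (an illustrative sketch of the mismatch, not Triton's actual launcher code, and assuming a little-endian host): if the launcher hands over a 32-bit `float` while the kernel reads its argument as raw fp16 bits, the kernel picks up the low 16 bits of `42.0f`'s IEEE-754 encoding, which happen to be `0x0000`, i.e. fp16 `0.0`.
```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    /* Hypothetical illustration of the mismatch, not Triton's launcher code. */
    float passed = 42.0f;   /* what the old launcher effectively passed      */
    uint32_t float_bits;
    uint16_t fp16_slot;     /* what a kernel expecting tl.float16 would read */

    memcpy(&float_bits, &passed, sizeof(float_bits));
    memcpy(&fp16_slot, &passed, sizeof(fp16_slot));  /* low 16 bits on little-endian */

    printf("float32 bits of 42.0f: 0x%08x\n", float_bits); /* 0x42280000 */
    printf("bits read as fp16:     0x%04x\n", fp16_slot);  /* 0x0000 -> fp16 0.0 */
    return 0;
}
```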
This PR makes `PyArg_ParseTuple` always parse the Python float into a C
double, and then calls
[`PyFloat_Pack{2,4,8}`](https://docs.python.org/3/c-api/float.html#pack-functions)
to convert it to the proper storage type:
```
         PyArg_ParseTuple           PyFloat_Pack{2,4,8}                  passed to
PyFloat ==================> double ====================> uint{16,32,64}_t -------------> kernel that accepts tl.float{16,32,64}
```
The generated launcher code then looks something like this (for the AMD backend):
```c
double _arg1;
// "d" tells PyArg_ParseTuple to store the Python float as a C double.
PyArg_ParseTuple(args, "piiiKKOOOOOd", ..., &_arg1);
// Pack the double into its IEEE-754 binary16 (fp16) storage, little-endian.
uint16_t _arg1_storage = 0;
PyFloat_Pack2(_arg1, (void*)&_arg1_storage, 1);
_launch(gridX, gridY, gridZ, ..., _arg1_storage);
```
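For reference, `PyFloat_Pack2` (public in the C API since Python 3.11) writes the IEEE-754 binary16 encoding of a C double into a 2-byte buffer; the last argument selects little-endian when non-zero. A minimal standalone check, assuming a CPython ≥ 3.11 build environment:
```c
#include <Python.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    Py_Initialize();

    uint16_t storage = 0;
    /* Pack 42.0 into IEEE-754 binary16, little-endian byte order (le = 1). */
    PyFloat_Pack2(42.0, (char *)&storage, 1);

    printf("fp16 bits of 42.0: 0x%04x\n", storage);  /* expected: 0x5140 */

    Py_Finalize();
    return 0;
}
```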
- [x] Fix AMD backend
- [x] Fix NVIDIA backend
- [x] Add tests
(Diff excerpt from the AMD launcher generator, elided: alongside the existing `_launch` template, the change adds a `pack_fp16(double)` helper adapted from https://github.com/python/pythoncapi-compat.)