You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add extra attribute in PrintOp to propagate signness info (intel#4363)
The core Triton is a small number of people, and we receive many PRs
(thank
you!). To help us review your code more quickly, **if you are a new
contributor (less than 3 PRs merged) we ask that you complete the
following
tasks and include the filled-out checklist in your PR description.**
--------
This PR aims to address
triton-lang/triton#4248, to correctly
`device_print` the value if it is an signed integer.
The signness info is lost when lowering TTGIR to LLIR (e.g. `i32` is
always **signless** in MLIR), but the **lowered** data type is currently
being used for constructing the format specifier in the `PrintOpToLLVM`
implementation
(triton-lang/triton#4248 (comment)),
so a negative value is printed out as an unsigned int, thus confusing
users.
A minimal reproducer is
```python
import torch
import triton
import triton.language as tl
@triton.jit
def print_kernel(ptr):
value = tl.load(ptr)
tl.device_print("value in kernel from device_print", value)
print_kernel[(1,)](torch.tensor(10, dtype=torch.int32).cuda())
print_kernel[(1,)](torch.tensor(-10, dtype=torch.int32).cuda())
print_kernel[(1,)](torch.tensor((1 << 31) + 1000, dtype=torch.uint32).cuda())
```
Currently, it prints
```
pid (0, 0, 0) idx () value in kernel from device_print: 10
...
pid (0, 0, 0) idx () value in kernel from device_print: 4294967286
...
pid (0, 0, 0) idx () value in kernel from device_print: 2147484648
```
(always as unsigned int)
This PR adds extra `isSigned` attribute in the `PrintOp` to indicate if
each operand in the `PrintOp` should be printed as signed or not.
With this, the program above now prints correctly
```
pid (0, 0, 0) idx () value in kernel from device_print: 10
...
pid (0, 0, 0) idx () value in kernel from device_print: -10
...
pid (0, 0, 0) idx () value in kernel from device_print: 2147484648
```
Extra LIT tests and python unit tests are added as well; also manually
verified that they failed without the fix and passing now by running
```
$ pytest python/test/unit/language/test_subprocess.py
$ cd python/build/cmake.linux-x86_64-cpython-3.10; lit test
```
**Alternative considered**: adds `uint32` in the triton MLIR data type
definition and then rely on the triton op data type to determine the
format specifier, to retain the original signness info; as in commit
triton-lang/triton@f7a7407.
However, as PR reviewer pointed out, that means adding a new data type
in the Triton IR just for this purpose, which is overkill and introduces
unnecessary maintenance overhead and thus less ideal.
--------
Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them.
- [X] I am not making a trivial change, such as fixing a typo in a
comment.
- [X] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
- [X] I have added tests.
- `/test` for `lit` tests
- `/unittest` for C++ tests
- `/python/test` for end-to-end tests
- [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
- [ ] I have not added any `lit` tests.
- [X] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
and using the instructions it generates is not minimal.)
0 commit comments