Skip to content

Commit 239f920

Browse files
authored
Add mask to tl.device_assert (triton-lang#7905)
<!--- The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, **if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [x] I have added tests. Add `mask=mask` to `tl.device_assert` Frequently in Triton kernels, data is loaded with a mask. The data outside the mask has unknown value. This makes it frustrating to write `tl.device_assert`s. In particular, I have to use this pattern: `tl.device_assert((data == 1337) | ~mask)`. This works fine; it's just a little unpleasant, and you have to remember that `==` has lower precedence than `|`. With this PR, you can instead write `tl.device_assert(data == 1337, mask=mask)`. Representative example: ```py import torch import triton.language as tl import triton torch.set_default_device("cuda") @triton.jit def kernel(data_ptr, len): BLOCK_SZ: tl.constexpr = 32 mask = tl.arange(0, BLOCK_SZ) < len data = tl.load(data_ptr + tl.arange(0, BLOCK_SZ), mask=mask) tl.device_assert(data == 1337, mask=mask) tensor = torch.full((10,), 1337, dtype=torch.int32) print("succeeds:") kernel[(1,)](tensor, len(tensor)) print("fails:") kernel[(1,)](tensor, len(tensor) + 1) # the next int is probably not 1337 ```
1 parent 95792de commit 239f920

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

python/test/unit/test_debug.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,32 @@
55

66

77
@pytest.mark.parametrize('cond', [True, False])
8+
@pytest.mark.parametrize('mask', [True, False, None])
89
@pytest.mark.parametrize('opt_flag', [True, False, None])
910
@pytest.mark.parametrize('env_var', [True, False])
1011
@pytest.mark.parametrize('jit_flag', [True, False])
1112
@pytest.mark.forked
12-
def test_device_assert(monkeypatch, cond, opt_flag, env_var, jit_flag, device):
13+
def test_device_assert(monkeypatch, cond, mask, opt_flag, env_var, jit_flag, device):
1314
monkeypatch.setenv("TRITON_DEBUG", str(int(env_var)))
1415
torch.zeros([1], dtype=torch.int32, device=device)
1516

1617
@triton.jit(debug=jit_flag)
17-
def _kernel(COND: tl.constexpr):
18-
tl.device_assert(COND, 'test')
18+
def _kernel(COND: tl.constexpr, MASK: tl.constexpr):
19+
tl.device_assert(COND, 'test', mask=MASK)
1920

2021
is_debug = env_var or (opt_flag if opt_flag is not None else jit_flag)
2122

2223
kwargs = {}
2324
if opt_flag is not None:
2425
kwargs["debug"] = opt_flag
2526

26-
if not cond and is_debug:
27+
if not cond and is_debug and mask is not False:
2728
with pytest.raises(RuntimeError):
28-
_kernel[(1, )](cond, **kwargs)
29+
_kernel[(1, )](cond, mask, **kwargs)
2930
getattr(torch, device).synchronize()
3031
return
3132

32-
_kernel[(1, )](cond, **kwargs)
33+
_kernel[(1, )](cond, mask, **kwargs)
3334
getattr(torch, device).synchronize()
3435

3536

python/triton/language/core.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2984,7 +2984,7 @@ def device_print(prefix, *args, hex=False, _semantic=None):
29842984

29852985

29862986
@builtin
2987-
def device_assert(cond, msg="", _semantic=None):
2987+
def device_assert(cond, msg="", mask=None, _semantic=None):
29882988
'''
29892989
Assert the condition at runtime from the device. Requires that the environment variable :code:`TRITON_DEBUG`
29902990
is set to a value besides :code:`0` in order for this to have any effect.
@@ -3003,7 +3003,10 @@ def device_assert(cond, msg="", _semantic=None):
30033003
:param msg: the message to print if the assertion fails. This is required to be a string literal.
30043004
'''
30053005
msg = _unwrap_if_constexpr(msg)
3006-
return _semantic.device_assert(_semantic.to_tensor(cond), msg)
3006+
mask = _unwrap_if_constexpr(mask)
3007+
if mask is not None:
3008+
mask = _semantic.to_tensor(mask)
3009+
return _semantic.device_assert(_semantic.to_tensor(cond), msg, mask)
30073010

30083011

30093012
@builtin

python/triton/language/semantic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_
219219
min_value = self.scalar_constant(min_value, tl.int64)
220220
cond = self.and_(self.less_equal(ret, max_value), self.greater_equal(ret, min_value))
221221
msg = f"int{lhs_sca_ty.int_bitwidth} overflow detected for operation {binary_op.__name__}"
222-
self.device_assert(cond, msg)
222+
self.device_assert(cond, msg, None)
223223

224224
def add(self, input: TensorTy | numbers.Number, other: TensorTy | numbers.Number,
225225
sanitize_overflow: bool) -> TensorTy:
@@ -1808,9 +1808,11 @@ def device_print(self, prefix: str, args: List[TensorTy], hex: bool) -> TensorTy
18081808
is_signed = [arg.dtype.is_int_signed() for arg in args]
18091809
return self.tensor(self.builder.create_print(prefix, hex, new_args, is_signed), tl.void)
18101810

1811-
def device_assert(self, cond: TensorTy, msg: str) -> TensorTy:
1811+
def device_assert(self, cond: TensorTy, msg: str, mask: Optional[TensorTy]) -> TensorTy:
18121812
if not self.builder.options.debug:
18131813
return
1814+
if mask is not None:
1815+
cond = self.or_(cond, self.not_(mask))
18141816
return self.tensor(self.builder.create_assert(cond.handle, msg), tl.void)
18151817

18161818
def assume(self, cond) -> TensorTy:

0 commit comments

Comments
 (0)