Commit 286e91f

Add implicit downcast in TMA descriptor store (#6236)
#### Description

This fixes a missing implicit downcast when storing blocks through TMA descriptors. Previously, widening the result of a descriptor load (e.g. from `float16` to `float32`) and then storing it back via the descriptor raised an MLIR verification error because the block types no longer matched:

```python
# ptr.element_ty is tl.float16
desc = tl._experimental_make_tensor_descriptor(ptr, shape=..., strides=..., block_shape=...)
value = desc.load([off_x, off_y]).to(tl.float32)
# 'tt.experimental_descriptor_store' op tensor desciptor block and tensor types must match
desc.store([off_x, off_y], value)
```

The pointer/`tl.store` path already casts values to the target element type; descriptor stores should behave the same.

#### Changes

* Updated `descriptor_store` in `python/triton/language/semantic.py` to cast the incoming tensor to the descriptor's element type before emitting the `create_descriptor_store` IR node.
* Added a regression test, `test_tensor_descriptor_store_downcast`, in `python/test/unit/language/test_tensor_descriptor.py`, which widens a `float16`/`bfloat16` block to `float32` and stores it back via the descriptor.
* Ran `pre-commit` hooks to keep formatting consistent.

A quick check under `TRITON_INTERPRET=1` shows the new downcast path works:

```
True # torch.equal(a, out) when storing a widened float16 block
True # bfloat16 as well
```

#### Checklist

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- [x] I have added tests under `python/test`.
- [x] I have not added any `lit` tests.

---

This fix aligns descriptor stores with pointer-store semantics and avoids an IR verifier failure when the stored block's element type is wider than the descriptor's element type.

Co-authored-by: Thomas Raoux <[email protected]>
1 parent ba3ec66 commit 286e91f

File tree

2 files changed: +29 −0 lines changed

python/test/unit/language/test_tensor_descriptor.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -1671,3 +1671,30 @@ def test_host_tensor_descriptor_matmul(num_stages, num_ctas, BLOCK_M, BLOCK_N, B
     # Only a subset of TMEM and stmatrix layout pairs are compatible, for example 16x256bx2 and m8n8x4.
     assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in kernel.asm[
         "ptx"] or "stmatrix.sync.aligned.x4.m8n8.shared.b16" in kernel.asm["ptx"]
+
+
+@pytest.mark.interpreter
+@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
+def test_tensor_descriptor_store_downcast(dtype_str, device):
+
+    @triton.jit
+    def kernel(desc, M, N, M_BLOCK: tl.constexpr, N_BLOCK: tl.constexpr):
+        moffset = tl.program_id(axis=0) * M_BLOCK
+        noffset = tl.program_id(axis=1) * N_BLOCK
+        midx = moffset + tl.arange(0, M_BLOCK)[:, None]
+        nidx = noffset + tl.arange(0, N_BLOCK)[None, :]
+        val_f32 = (midx * N + nidx).to(tl.float32)
+        # implicit downcast in the store.
+        desc.store([moffset, noffset], val_f32)
+
+    M, N = 32, 128
+    torch_dtype = getattr(torch, dtype_str)
+    M_BLOCK = 8
+    N_BLOCK = 32
+    grid_m = M // M_BLOCK
+    grid_n = N // N_BLOCK
+    out = torch.empty((M, N), dtype=torch_dtype, device=device)
+    desc = TensorDescriptor(out, out.shape, out.stride(), [M_BLOCK, N_BLOCK])
+    kernel[(grid_m, grid_n)](desc, M, N, M_BLOCK=M_BLOCK, N_BLOCK=N_BLOCK)
+    ref = torch.arange(M * N, dtype=torch.float32, device=device).reshape(M, N).to(torch_dtype)
+    torch.testing.assert_close(out, ref)
```
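Each program instance in the test kernel fills its block with row-major linear indices (`midx * N + nidx`), so tiling the full grid reproduces `arange(M * N).reshape(M, N)`. A small NumPy sketch of that index arithmetic (NumPy stands in for the Triton ops here, purely for illustration):

```python
import numpy as np

# Same shapes as the regression test.
M, N, M_BLOCK, N_BLOCK = 32, 128, 8, 32

out = np.empty((M, N), dtype=np.float32)
for pid_m in range(M // M_BLOCK):
    for pid_n in range(N // N_BLOCK):
        # Mirrors tl.program_id(...) * BLOCK in the kernel.
        moffset = pid_m * M_BLOCK
        noffset = pid_n * N_BLOCK
        # Mirrors tl.arange with [:, None] / [None, :] broadcasting.
        midx = moffset + np.arange(M_BLOCK)[:, None]
        nidx = noffset + np.arange(N_BLOCK)[None, :]
        # Row-major linear index for this block.
        out[moffset:moffset + M_BLOCK, noffset:noffset + N_BLOCK] = midx * N + nidx

ref = np.arange(M * N, dtype=np.float32).reshape(M, N)
assert np.array_equal(out, ref)
```

This is why the test can compare against a plain `torch.arange` reference regardless of how the grid is tiled.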

python/triton/language/semantic.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1107,6 +1107,8 @@ def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy,
 
 def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
     self.validate_store_like(desc, value, offsets)
+    # implicitly cast to the descriptor's type
+    value = self.cast(value, desc.dtype)
     offsets = self._convert_to_ir_values(offsets, require_i64=False)
     return self.tensor(self.builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)
```
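A pure-Python model of the cast-before-store behavior this hunk adds. `ToyDescriptor` is a hypothetical stand-in (not Triton's actual API), with NumPy playing the role of the IR-level tensors:

```python
import numpy as np

class ToyDescriptor:
    """Illustrative model of a TMA-style descriptor over a fixed-dtype buffer."""

    def __init__(self, buffer, block_shape):
        self.buffer = buffer            # backing tensor; its dtype is the descriptor's element type
        self.block_shape = block_shape  # (M_BLOCK, N_BLOCK)

    def store(self, offsets, block):
        m, n = offsets
        bm, bn = self.block_shape
        # The fix: implicitly cast the incoming block to the descriptor's
        # element type before storing, instead of failing on a type mismatch.
        block = block.astype(self.buffer.dtype)
        self.buffer[m:m + bm, n:n + bn] = block

out = np.zeros((4, 8), dtype=np.float16)
desc = ToyDescriptor(out, (2, 4))
widened = np.arange(8, dtype=np.float32).reshape(2, 4)  # a float32 block
desc.store((0, 0), widened)  # float32 -> float16 downcast happens implicitly
```

This mirrors how the pointer/`tl.store` path already behaved: the store site, not the caller, is responsible for converging on the destination's element type.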
