You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add implicit downcast in TMA descriptor store (#6236)
#### Description
This fixes a missing implicit downcast when storing blocks through TMA
descriptors. Previously, attempting to widen the result of a descriptor
load (e.g. from `float16` to `float32`) and then store it back via the
descriptor would result in an MLIR verification error because the block
types no longer matched:
```python
# ptr.element_ty is tl.float16
desc = tl._experimental_make_tensor_descriptor(ptr, shape=.., strides=..., block_shape=...)
value = desc.load([off_x, off_y]).to(tl.float32)
# 'tt.experimental_descriptor_store' op tensor desciptor block and tensor types must match
desc.store([off_x, off_y], value)
```
The pointer/`tl.store` path already cast values to the target element
type; descriptor stores should behave the same.
#### Changes
* Updated `descriptor_store` in `python/triton/language/semantic.py` to
cast the incoming tensor to the descriptor's element type before
emitting the `create_descriptor_store` IR node.
* Added a regression test `test_tensor_descriptor_store_downcast` to
`python/test/unit/cuda/test_experimental_tma.py` which widens a
`float16`/`bfloat16` block to `float32` and stores it back via the
descriptor.
* Ran `pre-commit` hooks to keep formatting consistent.
A quick check under `TRITON_INTERPRET=1` shows the new downcast path
works:
```
True # torch.equal(a, out) when storing a widened float16 block
True # bfloat16 as well
```
#### Checklist
- [x] I am not making a trivial change, such as fixing a typo in a
comment.
- [x] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- [x] I have added tests under `python/test`.
- [x] I have not added any `lit` tests.
---
This fix aligns descriptor stores with pointer store semantics and
avoids an IR verifier failure when the stored block's element type is
wider than the descriptor’s element type.
Co-authored-by: Thomas Raoux <[email protected]>
0 commit comments