You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
This diff adds the `reinterpret_tensor_descriptor` API to TLX for converting raw pointers to TMA tensor descriptors, and fixes the conditional purity modeling of `TT_MakeTensorDescOp`.
1. **New `reinterpret_tensor_descriptor` API** (`mem_ops.py`, `__init__.py`):
- Takes a `desc_ptr` (pointer to global memory containing a TMA descriptor), `block_shape`, and `dtype`
- Returns a `tl.tensor_descriptor_base` that can be used with TMA operations
- Useful when working with pre-allocated descriptor pointers from `tlx.global_alloc`
2. **Fixed conditional purity for `TT_MakeTensorDescOp`** (`TritonOps.td`, `Ops.cpp`):
- Removed the `Pure` trait (which cannot be conditional in MLIR ODS)
- Implemented `MemoryEffectOpInterface::getEffects()` to conditionally add memory effects
- When `descPtr` operand is present, adds `MemoryEffects::Write` to global memory
- When `descPtr` is absent, no effects are added (operation is effectively pure)
3. **Added unit test** (`test_tlx.py`):
- Tests the new `reinterpret_tensor_descriptor` API with TMA operations
- Validates that both `ttg.global_scratch_alloc` and `ttng.reinterpret_tensor_descriptor` operations are generated
- Verifies data correctness through TMA load/store operations
The `reinterpret_tensor_descriptor` API enables tensor descriptor pipelining patterns where descriptors are allocated in global scratch memory and reused across kernel invocations. This is critical for performance optimization on Hopper/Blackwell GPUs.
The conditional purity fix ensures that the MLIR compiler correctly models side effects: when a descriptor pointer is provided, the operation writes to global memory (impure); when auto-allocated, it has no observable side effects (pure). This follows the proper MLIR idiom since ODS traits are compile-time only and cannot be toggled at runtime.
Pull Request resolved: #715
Test Plan:
Existing tests pass:
- `test_reinterpret_tensor_descriptor` validates the new API with TMA operations
- Verifies correct MLIR operation generation (global_scratch_alloc, reinterpret_tensor_descriptor)
- Confirms data correctness through TMA load/store round-trip
The conditional memory effect modeling is validated by the MLIR compiler infrastructure which uses `getEffects()` during optimization passes.
Performance:
For groupedGEMM with memory-bound shapes such as G=16, M=8192, N=512, K=256
```
before:
x_val preprocessed_aten_grouped_mm-tflops tlx_grouped_gemm-tflops
------- ------------------------------------- -------------------------
8192 283.609 243.148
after:
x_val preprocessed_aten_grouped_mm-tflops tlx_grouped_gemm-tflops
------- ------------------------------------- -------------------------
8192 283.459 274.755
```
Reviewed By: manman-ren
Differential Revision: D88292292
Pulled By: htyu
fbshipit-source-id: ad1a087071166a12670b42511b2f5ffeea220610
0 commit comments