Skip to content

Commit 5d1c25f

Browse files
htyumeta-codesync[bot]
authored andcommitted
[TLX] Enable TMA desc creation pipelining (#715)
Summary: This diff adds the `reinterpret_tensor_descriptor` API to TLX for converting raw pointers to TMA tensor descriptors, and fixes the conditional purity modeling of `TT_MakeTensorDescOp`. 1. **New `reinterpret_tensor_descriptor` API** (`mem_ops.py`, `__init__.py`): - Takes a `desc_ptr` (pointer to global memory containing a TMA descriptor), `block_shape`, and `dtype` - Returns a `tl.tensor_descriptor_base` that can be used with TMA operations - Useful when working with pre-allocated descriptor pointers from `tlx.global_alloc` 2. **Fixed conditional purity for `TT_MakeTensorDescOp`** (`TritonOps.td`, `Ops.cpp`): - Removed the `Pure` trait (which cannot be conditional in MLIR ODS) - Implemented `MemoryEffectOpInterface::getEffects()` to conditionally add memory effects - When `descPtr` operand is present, adds `MemoryEffects::Write` to global memory - When `descPtr` is absent, no effects are added (operation is effectively pure) 3. **Added unit test** (`test_tlx.py`): - Tests the new `reinterpret_tensor_descriptor` API with TMA operations - Validates that both `ttg.global_scratch_alloc` and `ttng.reinterpret_tensor_descriptor` operations are generated - Verifies data correctness through TMA load/store operations The `reinterpret_tensor_descriptor` API enables tensor descriptor pipelining patterns where descriptors are allocated in global scratch memory and reused across kernel invocations. This is critical for performance optimization on Hopper/Blackwell GPUs. The conditional purity fix ensures that the MLIR compiler correctly models side effects: when a descriptor pointer is provided, the operation writes to global memory (impure); when auto-allocated, it has no observable side effects (pure). This follows the proper MLIR idiom since ODS traits are compile-time only and cannot be toggled at runtime. Pull Request resolved: #715 Test Plan: Existing tests pass: - `test_reinterpret_tensor_descriptor` validates the new API with TMA operations - Verifies correct MLIR operation generation (global_scratch_alloc, reinterpret_tensor_descriptor) - Confirms data correctness through TMA load/store round-trip The conditional memory effect modeling is validated by the MLIR compiler infrastructure which uses `getEffects()` during optimization passes. Performance: For groupedGEMM with memory-bound shapes such as G=16, M=8192, N=512, K=256 ``` before: x_val preprocessed_aten_grouped_mm-tflops tlx_grouped_gemm-tflops ------- ------------------------------------- ------------------------- 8192 283.609 243.148 after: x_val preprocessed_aten_grouped_mm-tflops tlx_grouped_gemm-tflops ------- ------------------------------------- ------------------------- 8192 283.459 274.755 ``` Reviewed By: manman-ren Differential Revision: D88292292 Pulled By: htyu fbshipit-source-id: ad1a087071166a12670b42511b2f5ffeea220610
1 parent 9346bb4 commit 5d1c25f

File tree

9 files changed

+475
-122
lines changed

9 files changed

+475
-122
lines changed

README.md

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,33 @@ While this approach places more responsibility on the user, it reduces the compi
6969

7070
Store a chunk of data from local memory into global memory buffer. The global address, strides, and buffer size are defined by the memory descriptor.
7171

72+
- `desc_ptrs = tlx.allocate_tensor_descriptor(num)`
73+
74+
Allocates global memory for tensor descriptor storage with built-in parameters (nbytes=128, alignment=128 per descriptor).
75+
Returns a `tensor_descriptor_ptr` with 128-byte stride semantics that supports indexing.
76+
77+
**Parameters:**
78+
- `num`: Number of tensor descriptors to allocate (must be a constexpr)
79+
80+
**Returns:**
81+
- A `tensor_descriptor_ptr` where indexing (e.g., `desc_ptrs[0]`, `desc_ptrs[1]`) advances by 128 bytes per index
82+
83+
**Example:**
84+
```python
85+
# Allocate storage for 4 tensor descriptors
86+
desc_ptrs = tlx.allocate_tensor_descriptor(num=4)
87+
88+
# Access individual descriptors using indexing
89+
desc_ptr_0 = desc_ptrs[0] # First descriptor
90+
desc_ptr_1 = desc_ptrs[1] # Second descriptor (128 bytes offset)
91+
```
92+
7293
- `tlx.make_tensor_descriptor(desc_ptr, base, shape, strides, block_shape, padding_option)`
7394

7495
Create a TMA (Tensor Memory Accelerator) descriptor for efficient asynchronous data movement on Hopper and Blackwell GPUs.
7596

7697
**Parameters:**
77-
- `desc_ptr` (optional): Pointer to global memory for descriptor storage. Pass `None` for automatic allocation.
98+
- `desc_ptr` (optional): Tensor descriptor pointer from `allocate_tensor_descriptor()`. Pass `None` for automatic allocation.
7899
- `base`: Base pointer to the tensor in global memory
79100
- `shape`: List of tensor dimensions (dynamic, runtime values)
80101
- `strides`: List of tensor strides (dynamic, runtime values)
@@ -92,15 +113,47 @@ While this approach places more responsibility on the user, it reduces the compi
92113
block_shape=[64, 64],
93114
)
94115

95-
# Or with explicit scratch allocation for advanced use cases
96-
desc_ptr = tlx.global_alloc(nbytes=128, alignment=128)
97-
desc = tlx.make_tensor_descriptor(
98-
desc_ptr=desc_ptr,
116+
# Or with explicit descriptor allocation for advanced use cases (e.g., pipelining)
117+
desc_ptrs = tlx.allocate_tensor_descriptor(num=2)
118+
119+
# Create descriptor at index 0
120+
tlx.make_tensor_descriptor(
121+
desc_ptr=desc_ptrs[0],
99122
base=tensor_ptr,
100123
shape=[M, N],
101124
strides=[N, tl.constexpr(1)],
102125
block_shape=[64, 64],
103126
)
127+
128+
# Reinterpret the descriptor for TMA operations
129+
desc = tlx.reinterpret_tensor_descriptor(
130+
desc_ptr=desc_ptrs[0],
131+
block_shape=[64, 64],
132+
dtype=tl.float16,
133+
)
134+
135+
# Use with async TMA operations
136+
tlx.async_descriptor_load(desc, buffer, offsets=[m_offset, n_offset], barrier=mbar)
137+
```
138+
139+
- `desc = tlx.reinterpret_tensor_descriptor(desc_ptr, block_shape, dtype)`
140+
141+
Reinterpret a tensor descriptor pointer as a TMA-backed tensor descriptor object.
142+
143+
**Parameters:**
144+
- `desc_ptr`: A `tensor_descriptor_ptr` pointing to the TMA descriptor (from `allocate_tensor_descriptor`)
145+
- `block_shape`: Shape of the block to be loaded/stored (compile-time constants)
146+
- `dtype`: Data type of the tensor elements
147+
148+
**Example:**
149+
```python
150+
# Allocate and create descriptor
151+
desc_ptrs = tlx.allocate_tensor_descriptor(num=2)
152+
tlx.make_tensor_descriptor(desc_ptr=desc_ptrs[0], base=a_ptr, shape=[M, K], strides=[K, 1], block_shape=[128, 64])
153+
154+
# Reinterpret for use with TMA
155+
a_desc = tlx.reinterpret_tensor_descriptor(desc_ptr=desc_ptrs[0], block_shape=[128, 64], dtype=tl.float16)
156+
tlx.async_descriptor_load(a_desc, buffer, offsets=[offs_m, offs_k], barrier=mbar)
104157
```
105158

106159
- `tlx.async_load(tensor_ptr, buffer, optional_mask, optional_other, cache_modifier, eviction_policy, is_volatile)`

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1053,8 +1053,8 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
10531053
// Make Tensor Descriptor Op
10541054
//
10551055
def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
1056-
Pure,
10571056
AttrSizedOperandSegments,
1057+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
10581058
]> {
10591059
let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
10601060

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,16 @@ void MakeTensorDescOp::print(OpAsmPrinter &p) {
12881288
p << " : " << getBase().getType() << ", " << getType();
12891289
}
12901290

1291+
void MakeTensorDescOp::getEffects(
1292+
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1293+
&effects) {
1294+
// If descPtr operand is present, this operation writes to global memory
1295+
if (getDescPtr()) {
1296+
effects.emplace_back(MemoryEffects::Write::get(), GlobalMemory::get());
1297+
}
1298+
// Otherwise, the operation is pure (no effects)
1299+
}
1300+
12911301
// The following ops, including `call`, `func`, and `return` are copied and
12921302
// modified from
12931303
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Func/IR/FuncOps.cpp

python/test/unit/language/test_tlx.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2634,7 +2634,7 @@ def stoch_round_seed_kernel(x_ptr, y_ptr, seed, BLOCK_SIZE: tl.constexpr):
26342634

26352635
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
26362636
def test_make_tensor_descriptor(device):
2637-
"""Test global_alloc and make_tensor_descriptor together with TMA operations."""
2637+
"""Test allocate_tensor_descriptor and make_tensor_descriptor together with TMA operations."""
26382638

26392639
def alloc_fn(size: int, align: int, stream: Optional[int]):
26402640
assert align == 128
@@ -2643,20 +2643,20 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
26432643

26442644
@triton.jit
26452645
def kernel(input_ptr, output_ptr, SIZE, BLOCK_SIZE: tl.constexpr):
2646-
# Allocate descriptor in global scratch memory using global_alloc
2647-
desc_ptr = tlx.global_alloc(nbytes=256, alignment=128)
2646+
# Allocate descriptor in global scratch memory using allocate_tensor_descriptor
2647+
desc_ptrs = tlx.allocate_tensor_descriptor(num=2)
26482648

26492649
# Create tensor descriptor using the global scratch pointer
2650-
desc_in = tlx.make_tensor_descriptor(
2651-
desc_ptr=desc_ptr,
2650+
tlx.make_tensor_descriptor(
2651+
desc_ptr=desc_ptrs[0],
26522652
base=input_ptr,
26532653
shape=[SIZE],
26542654
strides=[tl.constexpr(1)],
26552655
block_shape=[BLOCK_SIZE],
26562656
)
26572657

2658-
desc_out = tlx.make_tensor_descriptor(
2659-
desc_ptr=desc_ptr + 128,
2658+
tlx.make_tensor_descriptor(
2659+
desc_ptr=desc_ptrs[1],
26602660
base=output_ptr,
26612661
shape=[SIZE],
26622662
strides=[tl.constexpr(1)],
@@ -2668,6 +2668,17 @@ def kernel(input_ptr, output_ptr, SIZE, BLOCK_SIZE: tl.constexpr):
26682668
offset = pid * BLOCK_SIZE
26692669

26702670
# Load and store using standard descriptors
2671+
# Reinterpret pointers as tensor descriptors
2672+
desc_in = tlx.reinterpret_tensor_descriptor(
2673+
desc_ptr=desc_ptrs[0],
2674+
block_shape=[BLOCK_SIZE],
2675+
dtype=tlx.dtype_of(input_ptr),
2676+
)
2677+
desc_out = tlx.reinterpret_tensor_descriptor(
2678+
desc_ptr=desc_ptrs[1],
2679+
block_shape=[BLOCK_SIZE],
2680+
dtype=tlx.dtype_of(output_ptr),
2681+
)
26712682
x = desc_in.load([offset])
26722683
desc_out.store([offset], x)
26732684

@@ -2684,6 +2695,7 @@ def kernel(input_ptr, output_ptr, SIZE, BLOCK_SIZE: tl.constexpr):
26842695
ttgir = compiled_kernel.asm["ttgir"]
26852696
assert ttgir.count("ttg.global_scratch_alloc") == 1, "Expected 1 global_scratch_alloc operation"
26862697
assert ttgir.count("ttng.tensormap_create") == 2, "Expected 2 tensormap_create operations"
2698+
assert ttgir.count("ttng.reinterpret_tensor_descriptor") == 2, "Expected 2 reinterpret_tensor_descriptor operations"
26872699

26882700
# Verify the data was copied correctly through TMA operations
26892701
torch.testing.assert_close(x, y)

test/TritonNvidiaGPU/tma_lowering.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,60 @@ tt.func @tma_scatter(%arg0: !tt.tensordesc<tensor<1x128xbf16, #nvmma_128>>, %arg
103103
// CHECK-NEXT: ttng.async_tma_store_wait
104104
tt.descriptor_scatter %arg0[%arg1, %arg2], %arg3 : !tt.tensordesc<tensor<1x128xbf16, #nvmma_128>>, tensor<32xi32, #blocked>, i32, tensor<32x128xbf16, #blocked1>
105105
tt.return
106+
}
107+
106108
}
107109

110+
// -----
111+
112+
#nvmma_32 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
113+
114+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
115+
// Test that MakeTensorDescOp without descPtr has no memory effects (pure)
116+
// This enables CSE - duplicate operations with identical inputs can be eliminated
117+
// CHECK-LABEL: make_tensor_descriptor_pure
118+
tt.func public @make_tensor_descriptor_pure(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) -> !tt.tensordesc<tensor<8x32xi8, #nvmma_32>> {
119+
%c1_i64 = arith.constant 1 : i64
120+
%0 = arith.extsi %arg2 : i32 to i64
121+
// Without descPtr, the operation has no observable side effects
122+
// Both calls have identical inputs, so CSE should eliminate one
123+
%1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr<i8>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
124+
%2 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : !tt.ptr<i8>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
125+
// CHECK: %[[ALLOC:.*]] = ttg.global_scratch_alloc
126+
// CHECK: ttng.tensormap_create %[[ALLOC]]
127+
// CHECK: ttng.tensormap_fenceproxy_acquire %[[ALLOC]]
128+
// CHECK: %[[DESC:.*]] = ttng.reinterpret_tensor_descriptor %[[ALLOC]]
129+
// CHECK-NOT: ttg.global_scratch_alloc
130+
// CHECK-NOT: ttng.tensormap_create
131+
// Both operations should be CSE'd into a single descriptor due to purity
132+
tt.return %1 : !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
133+
}
134+
}
135+
136+
// -----
137+
138+
#nvmma_32 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 8}>
139+
140+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
141+
// Test that MakeTensorDescOp with descPtr has memory effects (impure)
142+
// This prevents CSE - operations writing to different locations must be preserved
143+
// CHECK-LABEL: make_tensor_descriptor_impure
144+
tt.func public @make_tensor_descriptor_impure(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<i8> {tt.divisibility = 16 : i32}) -> (!tt.tensordesc<tensor<8x32xi8, #nvmma_32>>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>) {
145+
%c1_i64 = arith.constant 1 : i64
146+
%0 = arith.extsi %arg2 : i32 to i64
147+
// With descPtr, the operation writes to global memory (impure)
148+
// Both operations write to different locations (arg3 vs arg4), so both must be preserved
149+
%1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64], descPtr = %arg3 : !tt.ptr<i8> : !tt.ptr<i8>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
150+
%2 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64], descPtr = %arg4 : !tt.ptr<i8> : !tt.ptr<i8>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
151+
// CHECK: ttng.tensormap_create %arg3
152+
// CHECK: ttng.tensormap_fenceproxy_acquire %arg3
153+
// CHECK: %[[DESC1:.*]] = ttng.reinterpret_tensor_descriptor %arg3
154+
// CHECK: ttng.tensormap_create %arg4
155+
// CHECK: ttng.tensormap_fenceproxy_acquire %arg4
156+
// CHECK: %[[DESC2:.*]] = ttng.reinterpret_tensor_descriptor %arg4
157+
// Both operations must be preserved (no CSE) due to impurity
158+
tt.return %1, %2 : !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>, !tt.tensordesc<tensor<8x32xi8, #nvmma_32>>
159+
}
108160
}
109161

110162
// -----

third_party/tlx/language/tlx/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
clc_response_type,
1515
CLCPipelineContext,
1616
async_token,
17+
tensor_descriptor_ptr,
18+
tensor_descriptor_ptr_type,
1719
)
1820
from .mem_ops import (
1921
local_alloc,
@@ -28,12 +30,13 @@
2830
local_store,
2931
local_trans,
3032
local_reinterpret,
31-
global_alloc,
33+
allocate_tensor_descriptor,
3234
async_descriptor_load,
3335
async_descriptor_store,
3436
async_descriptor_store_wait,
3537
fence_async_shared,
3638
make_tensor_descriptor,
39+
reinterpret_tensor_descriptor,
3740
)
3841
from .barrier import (
3942
alloc_barriers,
@@ -88,6 +91,8 @@
8891
"clc_response_type",
8992
"CLCPipeliner",
9093
"async_token",
94+
"tensor_descriptor_ptr",
95+
"tensor_descriptor_ptr_type",
9196
# mem_ops
9297
"local_alloc",
9398
"local_view",
@@ -101,12 +106,13 @@
101106
"local_store",
102107
"local_trans",
103108
"local_reinterpret",
104-
"global_alloc",
109+
"allocate_tensor_descriptor",
105110
"async_descriptor_load",
106111
"async_descriptor_store",
107112
"async_descriptor_store_wait",
108113
"fence_async_shared",
109114
"make_tensor_descriptor",
115+
"reinterpret_tensor_descriptor",
110116
# barriers
111117
"alloc_barriers",
112118
"barrier_expect_bytes",

0 commit comments

Comments
 (0)