Skip to content

Commit e2f6663

Browse files
htyumeta-codesync[bot]
authored and committed
[TLX] Enable tensor descriptor pipelining (#706)
Summary: This PR enables tensor descriptor pipelining in TLX to improve performance of TMA operations on Hopper and Blackwell GPUs. The implementation includes a new make_tensor_descriptor API with custom MLIR parsing and support for automatic scratch memory allocation. More specifically, the Tensor Descriptor Pipelining Infrastructure includes: - Implemented pipelining support for tensor descriptors to enable efficient asynchronous data movement - Added support for automatic scratch memory allocation for descriptor storage - Updated TMA lowering pass to handle pipelined descriptor operations - New `tlx.make_tensor_descriptor` API Example usage ``` # For cases requiring manual memory management desc_ptr = tlx.global_alloc(nbytes=128, alignment=128) desc = tlx.make_tensor_descriptor( desc_ptr=desc_ptr, base=tensor_ptr, shape=[M, N], strides=[N, tl.constexpr(1)], block_shape=[64, 64], padding_option="zero", # Handle out-of-bounds accesses ) # Use the descriptor with async load/store operations buffer = tl.zeros([64, 64], dtype=tl.float16) tlx.async_descriptor_load(desc, buffer, [row_offset, col_offset]) ``` Changes are made to existing lit tests to maintain test compatibility with the new parser to avoid supporting legacy type parsing. Actually upstream already started adding new lit tests in that way. Pull Request resolved: #706 Reviewed By: njriasan Differential Revision: D88103067 Pulled By: htyu fbshipit-source-id: 0ca3340add4bae9693b81929c612e91499bb9b84
1 parent 37788c2 commit e2f6663

File tree

22 files changed

+479
-79
lines changed

22 files changed

+479
-79
lines changed

README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,39 @@ While this approach places more responsibility on the user, it reduces the compi
6969

7070
Store a chunk of data from local memory into global memory buffer. The global address, strides, and buffer size are defined by the memory descriptor.
7171

72+
- `tlx.make_tensor_descriptor(desc_ptr, base, shape, strides, block_shape, padding_option)`
73+
74+
Create a TMA (Tensor Memory Accelerator) descriptor for efficient asynchronous data movement on Hopper and Blackwell GPUs.
75+
76+
**Parameters:**
77+
- `desc_ptr` (optional): Pointer to global memory for descriptor storage. Pass `None` for automatic allocation.
78+
- `base`: Base pointer to the tensor in global memory
79+
- `shape`: List of tensor dimensions (dynamic, runtime values)
80+
- `strides`: List of tensor strides (dynamic, runtime values)
81+
- `block_shape`: Shape of the block to be loaded/stored (compile-time constants)
82+
- `padding_option`: Padding option for out-of-bounds accesses (default: "zero")
83+
84+
**Example:**
85+
```python
86+
# Create a 2D tensor descriptor with automatic scratch allocation
87+
desc = tlx.make_tensor_descriptor(
88+
desc_ptr=None, # Compiler allocates scratch memory automatically
89+
base=tensor_ptr,
90+
shape=[M, N],
91+
strides=[N, tl.constexpr(1)],
92+
block_shape=[64, 64],
93+
)
94+
95+
# Or with explicit scratch allocation for advanced use cases
96+
desc_ptr = tlx.global_alloc(nbytes=128, alignment=128)
97+
desc = tlx.make_tensor_descriptor(
98+
desc_ptr=desc_ptr,
99+
base=tensor_ptr,
100+
shape=[M, N],
101+
strides=[N, tl.constexpr(1)],
102+
block_shape=[64, 64],
103+
)
104+
```
72105

73106
- `tlx.async_load(tensor_ptr, buffer, optional_mask, optional_other, cache_modifier, eviction_policy, is_volatile)`
74107

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,7 +1054,7 @@ def TT_MakeTensorPtrOp : TT_Op<"make_tensor_ptr",
10541054
//
10551055
def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
10561056
Pure,
1057-
SameVariadicOperandSize,
1057+
AttrSizedOperandSegments,
10581058
]> {
10591059
let summary = "Make a tensor descriptor type with meta information of the parent tensor and block size";
10601060

@@ -1067,15 +1067,18 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
10671067
TT_Ptr:$base,
10681068
Variadic<I32>:$shape,
10691069
Variadic<I64>:$strides,
1070+
Optional<TT_Ptr>:$descPtr,
10701071
DefaultValuedAttr<TT_PaddingOptionAttr, "::mlir::triton::PaddingOption::PAD_ZERO">:$padding
10711072
);
10721073

10731074
let results = (outs TT_TensorDescType:$result);
10741075

1075-
let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
1076+
let hasCustomAssemblyFormat = 1;
10761077

10771078
let builders = [
10781079
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger,
1080+
"triton::PaddingOption":$padding)>,
1081+
OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "Value":$descPtr, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger,
10791082
"triton::PaddingOption":$padding)>
10801083
];
10811084

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,143 @@ void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
11491149
auto descTy =
11501150
TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
11511151
auto paddingAttr = PaddingOptionAttr::get(builder.getContext(), padding);
1152-
return build(builder, state, descTy, base, shape, strides, paddingAttr);
1152+
return build(builder, state, descTy, base, shape, strides,
1153+
/*descPtr=*/Value(), paddingAttr);
1154+
}
1155+
1156+
void MakeTensorDescOp::build(OpBuilder &builder, OperationState &state,
                             Value base, ValueRange shape, ValueRange strides,
                             Value descPtr, ArrayRef<int32_t> blockShape,
                             bool isSignedInteger,
                             triton::PaddingOption padding) {
  // Builder overload that threads an explicit descriptor-storage pointer
  // (descPtr) through to the generated build(). Pass a null Value to have
  // the TMA lowering allocate global scratch memory automatically.
  auto basePtrTy = dyn_cast<triton::PointerType>(base.getType());
  if (!basePtrTy)
    llvm::report_fatal_error("Expected pointer type");

  // The block shape arrives as i32 constants; RankedTensorType wants i64 dims.
  SmallVector<int64_t> dims(blockShape.begin(), blockShape.end());
  auto blockTy = RankedTensorType::get(dims, basePtrTy.getPointeeType());
  auto descTy =
      TensorDescType::get(builder.getContext(), blockTy, isSignedInteger);
  build(builder, state, descTy, base, shape, strides, descPtr,
        PaddingOptionAttr::get(builder.getContext(), padding));
}
1174+
1175+
ParseResult MakeTensorDescOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  // Custom assembly parser for:
  //   $base `,` `[` $shape `]` `,` `[` $strides `]`
  //   (`,` `descPtr` `=` $descPtr `:` type($descPtr))?
  //   attr-dict `:` type($base) `,` type($result)
  // A custom parser is required because the optional $descPtr operand cannot
  // be expressed in the declarative assembly format next to the two variadic
  // operand groups (AttrSizedOperandSegments).

  OpAsmParser::UnresolvedOperand base;
  SmallVector<OpAsmParser::UnresolvedOperand> shape;
  SmallVector<OpAsmParser::UnresolvedOperand> strides;
  Type baseType, resultType;

  // Base pointer operand.
  if (parser.parseOperand(base) || parser.parseComma())
    return failure();

  // Runtime tensor shape: `[` $shape `]`.
  if (parser.parseLSquare() ||
      parser.parseOperandList(shape, OpAsmParser::Delimiter::None) ||
      parser.parseRSquare() || parser.parseComma())
    return failure();

  // Runtime strides: `[` $strides `]`.
  if (parser.parseLSquare() ||
      parser.parseOperandList(strides, OpAsmParser::Delimiter::None) ||
      parser.parseRSquare())
    return failure();

  // Optional explicit descriptor-storage pointer.
  OpAsmParser::UnresolvedOperand descPtr;
  Type descPtrType;
  bool hasDescPtr = false;

  if (succeeded(parser.parseOptionalComma())) {
    if (succeeded(parser.parseOptionalKeyword("descPtr"))) {
      if (parser.parseEqual() || parser.parseOperand(descPtr) ||
          parser.parseColon() || parser.parseType(descPtrType))
        return failure();
      hasDescPtr = true;
    } else {
      // A trailing comma must introduce the descPtr clause.
      return parser.emitError(parser.getCurrentLocation(),
                              "expected 'descPtr' keyword");
    }
  }

  // Attribute dictionary (e.g. `padding`).
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Trailing types: `:` type($base) `,` type($result).
  if (parser.parseColon() || parser.parseType(baseType) ||
      parser.parseComma() || parser.parseType(resultType))
    return failure();

  // Resolve operands in ODS declaration order: base, shape..., strides...,
  // descPtr?.
  if (parser.resolveOperand(base, baseType, result.operands))
    return failure();

  // Shape operands are fixed to i32 by the op definition.
  auto i32Type = parser.getBuilder().getI32Type();
  if (parser.resolveOperands(shape, i32Type, result.operands))
    return failure();

  // Stride operands are fixed to i64 by the op definition.
  auto i64Type = parser.getBuilder().getI64Type();
  if (parser.resolveOperands(strides, i64Type, result.operands))
    return failure();

  // Resolve optional descPtr with its explicitly parsed type.
  if (hasDescPtr) {
    if (parser.resolveOperand(descPtr, descPtrType, result.operands))
      return failure();
  }

  // Record how many operands belong to each segment so the
  // AttrSizedOperandSegments trait can slice the flat operand list:
  //   [ base, shape..., strides..., descPtr? ]
  SmallVector<int32_t, 4> segmentSizes;
  segmentSizes.push_back(1);                                    // base
  segmentSizes.push_back(static_cast<int32_t>(shape.size()));   // shape
  segmentSizes.push_back(static_cast<int32_t>(strides.size())); // strides
  segmentSizes.push_back(hasDescPtr ? 1 : 0);                   // descPtr

  // Use the trait-provided attribute name instead of a hard-coded string:
  // the attribute was renamed across MLIR versions
  // ("operand_segment_sizes" -> "operandSegmentSizes"), and registering it
  // under the wrong name fails op verification and desynchronizes parse()
  // from print(), which elides the trait's canonical name.
  auto &builder = parser.getBuilder();
  result.addAttribute(getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(segmentSizes));

  // Result type.
  result.addTypes(resultType);

  return success();
}
1266+
1267+
void MakeTensorDescOp::print(OpAsmPrinter &p) {
  // Mirror of MakeTensorDescOp::parse:
  //   $base `,` `[` $shape `]` `,` `[` $strides `]`
  //   (`,` `descPtr` `=` $descPtr `:` type($descPtr))?
  //   attr-dict `:` type($base) `,` type($result)

  p << " " << getBase() << ", [" << getShape() << "], [" << getStrides() << "]";

  // Print descPtr (with its type, which is not recoverable from the
  // base/result types) only when present.
  if (getDescPtr()) {
    p << ", descPtr = " << getDescPtr() << " : " << getDescPtr().getType();
  }

  // Elide attributes that are implied by the format or hold default values.
  SmallVector<StringRef> elidedAttrs;
  // Use the trait accessor rather than a hard-coded string so the elided
  // name always matches the attribute the AttrSizedOperandSegments trait
  // (and parse()) actually uses, across MLIR attribute-name changes.
  elidedAttrs.push_back(getOperandSegmentSizeAttr());
  // Elide padding when it equals the ODS default (PAD_ZERO) so round-trips
  // stay terse.
  if (getPadding() == triton::PaddingOption::PAD_ZERO) {
    elidedAttrs.push_back("padding");
  }
  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  p << " : " << getBase().getType() << ", " << getType();
}
11541290

11551291
// The following ops, including `call`, `func`, and `return` are copied and

lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -197,14 +197,24 @@ class TMACreateDescLowering : public OpRewritePattern<MakeTensorDescOp> {
197197
PatternRewriter &rewriter) const override {
198198
MLIRContext *ctx = op.getContext();
199199
auto loc = op.getLoc();
200-
auto alloc = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
201-
loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, TMA_ALIGN);
202-
if (failed(createTMADesc(alloc, op, rewriter))) {
200+
201+
Value descPtr;
202+
// If desc_ptr is provided, use it directly without creating global scratch
203+
if (op.getDescPtr()) {
204+
descPtr = op.getDescPtr();
205+
} else {
206+
// Create global scratch allocation when desc_ptr is not provided
207+
auto alloc = rewriter.create<triton::gpu::GlobalScratchAllocOp>(
208+
loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, TMA_ALIGN);
209+
descPtr = alloc.getResult();
210+
}
211+
212+
if (failed(createTMADesc(descPtr, op, rewriter))) {
203213
return failure();
204214
}
205-
rewriter.create<TensormapFenceproxyAcquireOp>(loc, alloc.getResult());
206-
auto newDesc = rewriter.create<ReinterpretTensorDescOp>(loc, op.getType(),
207-
alloc.getResult());
215+
rewriter.create<TensormapFenceproxyAcquireOp>(loc, descPtr);
216+
auto newDesc =
217+
rewriter.create<ReinterpretTensorDescOp>(loc, op.getType(), descPtr);
208218
rewriter.replaceOp(op, newDesc);
209219
return success();
210220
}

python/test/unit/language/test_tlx.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,3 +2425,60 @@ def stoch_round_seed_kernel(x_ptr, y_ptr, seed, BLOCK_SIZE: tl.constexpr):
24252425
different_count = (b1.float() != b2.float()).sum().item()
24262426
assert different_count > SIZE * 0.1, (f"Different seeds should produce different results, "
24272427
f"but only {different_count}/{SIZE} values differ")
2428+
2429+
2430+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
def test_make_tensor_descriptor(device):
    """Test global_alloc and make_tensor_descriptor together with TMA operations."""

    # Allocator callback Triton invokes for kernel global-scratch requests.
    # The asserts pin the expected request: 128-byte alignment on the default
    # stream (stream id 0).
    def alloc_fn(size: int, align: int, stream: Optional[int]):
        assert align == 128
        assert stream == 0
        return torch.empty(size, dtype=torch.int8, device=device)

    @triton.jit
    def kernel(input_ptr, output_ptr, SIZE, BLOCK_SIZE: tl.constexpr):
        # Allocate descriptor in global scratch memory using global_alloc.
        # 256 bytes holds both descriptors below (each presumably occupies
        # 128 bytes, matching the 128-byte alignment — TODO confirm).
        desc_ptr = tlx.global_alloc(nbytes=256, alignment=128)

        # Create tensor descriptor using the global scratch pointer
        desc_in = tlx.make_tensor_descriptor(
            desc_ptr=desc_ptr,
            base=input_ptr,
            shape=[SIZE],
            strides=[tl.constexpr(1)],
            block_shape=[BLOCK_SIZE],
        )

        # Second descriptor placed 128 bytes past the first, within the same
        # 256-byte scratch allocation.
        desc_out = tlx.make_tensor_descriptor(
            desc_ptr=desc_ptr + 128,
            base=output_ptr,
            shape=[SIZE],
            strides=[tl.constexpr(1)],
            block_shape=[BLOCK_SIZE],
        )

        # Compute tile offset
        pid = tl.program_id(0)
        offset = pid * BLOCK_SIZE

        # Load and store using standard descriptors
        x = desc_in.load([offset])
        desc_out.store([offset], x)

    triton.set_allocator(alloc_fn)
    SIZE = 128
    BLOCK_SIZE = 64
    x = torch.ones((SIZE, ), dtype=torch.int16, device=device)
    y = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(SIZE, BLOCK_SIZE), )

    compiled_kernel = kernel[grid](x, y, SIZE, BLOCK_SIZE=BLOCK_SIZE)

    # Check that both global_scratch_alloc and tensormap_create were generated in IR.
    # Exactly one scratch alloc (from tlx.global_alloc; the explicit desc_ptr
    # means the lowering should not add its own), and one tensormap_create per
    # descriptor.
    ttgir = compiled_kernel.asm["ttgir"]
    assert ttgir.count("ttg.global_scratch_alloc") == 1, "Expected 1 global_scratch_alloc operation"
    assert ttgir.count("ttng.tensormap_create") == 2, "Expected 2 tensormap_create operations"

    # Verify the data was copied correctly through TMA operations
    torch.testing.assert_close(x, y)

test/Hopper/WarpSpecialization/blackwell_ws_data_partition.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ module attributes {ttg.max_reg_auto_ws = 152 : i32, ttg.min_reg_auto_ws = 24 : i
3737
%cst_2 = arith.constant dense<0.127517432> : tensor<256xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
3838
%cst_3 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #blocked>
3939
// CHECK-COUNT-8: tt.make_tensor_descriptor
40-
%q_desc = tt.make_tensor_descriptor %q, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : <bf16>, <tensor<1x256x128xbf16, #shared>>
41-
%k_desc = tt.make_tensor_descriptor %k, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : <bf16>, <tensor<1x128x128xbf16, #shared>>
42-
%v_desc = tt.make_tensor_descriptor %v, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : <bf16>, <tensor<1x128x128xbf16, #shared>>
43-
%lse_desc_4 = tt.make_tensor_descriptor %lse, [%c128_i32, %c8192_i32], [%lse_desc, %c1_i64] : <f32>, <tensor<1x256xf32, #shared1>>
44-
%o_desc = tt.make_tensor_descriptor %o, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : <bf16>, <tensor<1x256x128xbf16, #shared>>
40+
%q_desc = tt.make_tensor_descriptor %q, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<1x256x128xbf16, #shared>>
41+
%k_desc = tt.make_tensor_descriptor %k, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<1x128x128xbf16, #shared>>
42+
%v_desc = tt.make_tensor_descriptor %v, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<1x128x128xbf16, #shared>>
43+
%lse_desc_4 = tt.make_tensor_descriptor %lse, [%c128_i32, %c8192_i32], [%lse_desc, %c1_i64] : !tt.ptr<f32>, !tt.tensordesc<tensor<1x256xf32, #shared1>>
44+
%o_desc = tt.make_tensor_descriptor %o, [%c128_i32, %c8192_i32, %c128_i32], [%c1048576_i64, %c128_i64, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<1x256x128xbf16, #shared>>
4545
%0 = tt.get_program_id x : i32
4646
scf.for %virtual_pid = %0 to %total_pids step %c148_i32 : i32 {
4747
%pid_0 = arith.remsi %virtual_pid, %c32_i32 : i32

test/TLX/propagate-layout.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
213213
%0 = tt.get_program_id x : i32
214214
%1 = tt.get_program_id y : i32
215215
%2 = arith.extsi %arg3 : i32 to i64
216-
%3 = tt.make_tensor_descriptor %arg0, [%arg2, %arg3], [%2, %c1_i64] : <i16>, <tensor<64x64xsi16>>
216+
%3 = tt.make_tensor_descriptor %arg0, [%arg2, %arg3], [%2, %c1_i64] : !tt.ptr<i16>, !tt.tensordesc<tensor<64x64xsi16>>
217217
// CHECK: ttg.local_alloc : () -> !ttg.memdesc<1x64x64xi16, #[[$SHARED]], #smem, mutable>
218218
%4 = ttg.local_alloc : () -> !ttg.memdesc<1x64x64xi16, #shared, #smem, mutable>
219219
%5 = ttg.memdesc_index %4[%c0_i32] : !ttg.memdesc<1x64x64xi16, #shared, #smem, mutable> -> !ttg.memdesc<64x64xi16, #shared, #smem, mutable>
@@ -654,8 +654,8 @@ module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = t
654654
%10 = arith.muli %arg18, %arg15 : i32
655655
%11 = arith.muli %arg16, %c128_i32 : i32
656656
%12 = arith.extsi %11 : i32 to i64
657-
%13 = tt.make_tensor_descriptor %arg2, [%10, %11], [%12, %c1_i64] : <bf16>, <tensor<128x128xbf16>>
658-
%14 = tt.make_tensor_descriptor %arg4, [%10, %11], [%12, %c1_i64] : <bf16>, <tensor<128x128xbf16>>
657+
%13 = tt.make_tensor_descriptor %arg2, [%10, %11], [%12, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<128x128xbf16>>
658+
%14 = tt.make_tensor_descriptor %arg4, [%10, %11], [%12, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<128x128xbf16>>
659659
%15 = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xbf16, #shared, #smem, mutable>
660660
%16 = ttg.local_alloc : () -> !ttg.memdesc<1x128x128xbf16, #shared, #smem, mutable>
661661
%17 = ttg.local_alloc : () -> !ttg.memdesc<3x128x128xbf16, #shared, #smem, mutable>
@@ -784,7 +784,7 @@ module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = t
784784
%result_4 = ttng.tmem_load %76 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked2>
785785
%77 = tlx.release_layout %result_4 : tensor<128x128xf32, #blocked2> -> tensor<128x128xf32, #blocked3>
786786
ttng.arrive_barrier %45, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>
787-
%78 = tt.make_tensor_descriptor %arg5, [%58, %11], [%12, %c1_i64] : <bf16>, <tensor<128x128xbf16>>
787+
%78 = tt.make_tensor_descriptor %arg5, [%58, %11], [%12, %c1_i64] : !tt.ptr<bf16>, !tt.tensordesc<tensor<128x128xbf16>>
788788
%79 = arith.truncf %77 : tensor<128x128xf32, #blocked3> to tensor<128x128xbf16, #blocked3>
789789
%80 = arith.addi %56, %71 : i32
790790
%81 = arith.trunci %70 : i64 to i32
@@ -875,7 +875,7 @@ module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = t
875875
}
876876
%76 = arith.muli %arg21, %c128_i32_13 : i32
877877
%77 = arith.extsi %76 : i32 to i64
878-
%78 = tt.make_tensor_descriptor %arg24, [%63, %76], [%77, %c1_i64_6] : <bf16>, <tensor<128x128xbf16>>
878+
%78 = tt.make_tensor_descriptor %arg24, [%63, %76], [%77, %c1_i64_6] : !tt.ptr<bf16>, !tt.tensordesc<tensor<128x128xbf16>>
879879
%79 = ttg.memdesc_index %arg38[%c0_i32_11] : !ttg.memdesc<1x128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
880880
%result_14 = ttng.tmem_load %79 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked2>
881881
%80 = tlx.release_layout %result_14 : tensor<128x128xf32, #blocked2> -> tensor<128x128xf32, #blocked3>
@@ -1071,7 +1071,7 @@ module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = t
10711071
%75:2 = scf.if %74 -> (i32, i32) {
10721072
%77 = arith.muli %arg21, %c128_i32_7 : i32
10731073
%78 = arith.extsi %77 : i32 to i64
1074-
%79 = tt.make_tensor_descriptor %arg25, [%65, %77], [%78, %c1_i64_6] : <bf16>, <tensor<128x128xbf16>>
1074+
%79 = tt.make_tensor_descriptor %arg25, [%65, %77], [%78, %c1_i64_6] : !tt.ptr<bf16>, !tt.tensordesc<tensor<128x128xbf16>>
10751075
%80 = arith.andi %arg61, %c1_i32_10 : i32
10761076
%81 = ttg.memdesc_index %arg31[%c0_i32_9] : !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<1xi64, #shared1, #smem, mutable>
10771077
%82 = arith.xori %80, %c1_i32_10 : i32

0 commit comments

Comments
 (0)