Skip to content

Commit 4afd331

Browse files
authored
Avoid pass failure for descriptor load/store operations using a tensor descr. allocated on the host (#4812)
When a tensor descriptor is allocated on the host and passed to a triton kernel as a argument, the pass that transforms descriptor load/store operations into equivalent block pointer operations currently fails. This PR rectifies the situation and allow the pass to complete so that subsequently those operations can be lowered to use the 'unwrapped' descriptor argument passed to the kernel. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 2c8f3e2 commit 4afd331

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed
Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
// RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s
22

3-
// COM: Test make_tensor_descriptor is not rewritten when it is used by descriptor_gather.
4-
// CHECK-NOT: make_tensor_ptr
5-
// CHECK: tt.make_tensor_descriptor
6-
module {
7-
tt.func public @test_descriptor_gather(%arg0: !tt.ptr<f32>, %arg1: i64, %arg2: tensor<32xi32>, %arg3: i32) {
8-
%c128_i32 = arith.constant 128 : i32
9-
%0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32], [%arg1, %arg1] : <f32>, <tensor<1x32xf32>>
10-
%1 = tt.descriptor_gather %0[%arg2, %arg3] : (!tt.tensordesc<tensor<1x32xf32>>, tensor<32xi32>, i32) -> tensor<32x32xf32>
11-
tt.return
12-
}
3+
// COM: Test that `make_tensor_descriptor` is not rewritten when it is used by `descriptor_gather`.
4+
tt.func public @test_descriptor_gather(%arg0: !tt.ptr<f32>, %arg1: i64, %arg2: tensor<32xi32>, %arg3: i32) {
5+
// CHECK-NOT: make_tensor_ptr
6+
// CHECK: tt.make_tensor_descriptor
7+
%c128_i32 = arith.constant 128 : i32
8+
%0 = tt.make_tensor_descriptor %arg0, [%c128_i32, %c128_i32], [%arg1, %arg1] : <f32>, <tensor<1x32xf32>>
9+
%1 = tt.descriptor_gather %0[%arg2, %arg3] : (!tt.tensordesc<tensor<1x32xf32>>, tensor<32xi32>, i32) -> tensor<32x32xf32>
10+
tt.return
11+
}
12+
13+
// COM: Test that `descriptor_load/descriptor_store` operations are not rewritten if it they use a tensor descriptor function arg.
14+
tt.func public @test_host_descriptor(%desc: !tt.tensordesc<tensor<2x16xf16>>) {
15+
// CHECK: tt.func public @test_host_descriptor([[DESC:%.*]]: !tt.tensordesc<tensor<2x16xf16>>) {
16+
// CHECK: tt.descriptor_load [[DESC]]
17+
// CHECK: tt.descriptor_store [[DESC]]
18+
%c2_i32 = arith.constant 2 : i32
19+
%c32_i32 = arith.constant 32 : i32
20+
%0 = tt.descriptor_load %desc[%c2_i32, %c32_i32] : !tt.tensordesc<tensor<2x16xf16>> -> tensor<2x16xf16>
21+
tt.descriptor_store %desc[%c2_i32, %c32_i32], %0 : !tt.tensordesc<tensor<2x16xf16>>, tensor<2x16xf16>
22+
tt.return
1323
}

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,18 +232,24 @@ struct TritonIntelTensorDescToBlockPointer
232232
bool> = true>
233233
LogicalResult rewriteDescriptorLoadOrStoreOp(OpTy op) {
234234
assert(op && "Expecting a valid operation");
235+
236+
// At this point we expect to have transformed `make_tensor_descriptor` into
237+
// a `make_block_ptr` operation, except when the tensor descriptor is
238+
// allocated on the host and passed to the kernel as an argument.
239+
Value operand = op.getOperand(0);
240+
if (isa<tt::TensorDescType>(operand.getType()))
241+
return failure();
242+
235243
LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << op << "\n");
236244

237245
OpBuilder builder(op);
238246
Location loc = op.getLoc();
239-
Value ptr = op.getOperand(0);
240-
assert(triton::isTensorPointerType(ptr.getType()) &&
247+
assert(triton::isTensorPointerType(operand.getType()) &&
241248
"Expecting a block ptr");
242-
auto ptrType = cast<tt::PointerType>(ptr.getType());
249+
auto ptrType = cast<tt::PointerType>(operand.getType());
243250
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
244-
245-
ptr =
246-
builder.create<tt::AdvanceOp>(loc, ptr.getType(), ptr, op.getIndices());
251+
Value ptr =
252+
builder.create<tt::AdvanceOp>(loc, ptrType, operand, op.getIndices());
247253

248254
SmallVector<int32_t> boundaryCheck;
249255
for (size_t i = 0; i < tensorType.getRank(); ++i)

0 commit comments

Comments
 (0)