
Commit 9e5605a

[tensor_desc_to_block_ptr]: Create pass to transform tt.descriptor_load into tt.load with block ptr (#3756)
Our backend does not yet handle the `tt.descriptor_load` operation. As a temporary stop gap, add a new transformation at the Triton language level that replaces `tt.descriptor_load` with a `tt.load` and replaces its associated `tt.make_tensor_descriptor` operation with a `tt.make_block_ptr` operation.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent: dc6e17a
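In essence, the pass rewrites descriptor-based loads into block-pointer loads. A minimal before/after sketch, distilled from the lit tests added below (all SSA names, shapes, and constants are illustrative, and the `order` attribute on `tt.make_tensor_ptr` is an assumption — the tests elide the attribute with `{{.*}}`):

```mlir
// Before: tensor-descriptor form (not yet supported by the Intel backend).
%c0_i32 = arith.constant 0 : i32
%c8_i32 = arith.constant 8 : i32
%c1_i64 = arith.constant 1 : i64
%stride0 = arith.extsi %arg2 : i32 to i64
%desc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%stride0, %c1_i64] : <f16>, <tensor<16x32xf16>>
%v0 = tt.descriptor_load %desc[%c8_i32, %c0_i32] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>

// After: the descriptor's i32 shape operands are sign-extended to i64 and,
// together with the original strides, feed tt.make_tensor_ptr; the
// descriptor_load's indices become the block-pointer offsets, and the load
// itself becomes a plain tt.load of the block pointer.
%shape0 = arith.extsi %arg1 : i32 to i64
%shape1 = arith.extsi %arg2 : i32 to i64
%bptr = tt.make_tensor_ptr %arg0, [%shape0, %shape1], [%stride0, %c1_i64], [%c8_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x32xf16>>
%v1 = tt.load %bptr : !tt.ptr<tensor<16x32xf16>>
```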

7 files changed: +381 additions, −0 deletions

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
   mlir::triton::intel::registerConvertTritonToTritonGPUWarpPass();
+  mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
   mlir::triton::intel::registerTritonIntelRemoveMasks();
   mlir::triton::intel::registerTritonRaiseBlockPointer();
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+// RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s
+
+module {
+  tt.func public @test1(%arg0: !tt.ptr<i16>, %arg1: i32, %arg2: i32) {
+    %c1_i64 = arith.constant 1 : i64
+    %c64_i32 = arith.constant 64 : i32
+    %c8_i32 = arith.constant 8 : i32
+    %0 = arith.extsi %arg2 : i32 to i64
+    %1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <i16>, <tensor<8x32xi16>>
+    %2 = tt.descriptor_load %1[%c8_i32, %c64_i32] : !tt.tensordesc<tensor<8x32xi16>> -> tensor<8x32xi16>
+    tt.return
+  }
+
+  // CHECK: tt.func public @test1([[PARAM_0:%.+]]: !tt.ptr<i16>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
+  // CHECK-NOT: tt.make_tensor_descriptor
+  // CHECK-NOT: tt.descriptor_load
+  // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+  // CHECK-DAG: [[CST_64_i32:%.+]] = arith.constant 64 : i32
+  // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
+  // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
+  // CHECK-DAG: [[EXTSI_PARAM_2:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
+  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2]]], {{\[}}[[EXTSI_PARAM_2]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[CST_64_i32]]] {{.*}} : <tensor<8x32xi16>>
+  // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<8x32xi16>>
+  // CHECK: tt.return
+  // CHECK: }
+}
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+// RUN: triton-opt %s -triton-intel-tdesc-to-block-pointer | FileCheck %s
+
+module {
+  // COM: Loop containing a tensor descriptor load operation using a loop invariant tensor descriptor.
+  tt.func public @load_in_loop1(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %c1_i64 = arith.constant 1 : i64
+    %c8_i32 = arith.constant 8 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
+    %0 = arith.extsi %arg2 : i32 to i64
+    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
+    %tdesc_out, %sum_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc, %sum_iter = %cst) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
+      %cast_i = arith.index_cast %i : index to i32
+      %load1 = tt.descriptor_load %ptr_iter[%c8_i32, %cast_i] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
+      %sum_next = arith.addf %sum_iter, %load1 : tensor<16x32xf16>
+      scf.yield %ptr_iter, %sum_next : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
+    }
+    tt.return
+  }
+  // CHECK: tt.func public @load_in_loop1([[PARAM_0:%.+]]: !tt.ptr<f16>, [[PARAM_1:%.+]]: i32, [[PARAM_2:%.+]]: i32) {
+  // CHECK-NOT: tt.make_tensor_descriptor
+  // CHECK-NOT: tt.descriptor_load
+  // CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
+  // CHECK-DAG: [[CST_8_i32:%.+]] = arith.constant 8 : i32
+  // CHECK-DAG: [[CST:%.+]] = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
+  // CHECK-DAG: [[EXTSI_PARAM_2a:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
+  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} iter_args([[VAR_arg1:%.+]] = {{.*}}, [[VAR_arg2:%.+]] = [[CST]]) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
+  // CHECK-DAG: [[IDX_CAST_1:%.+]] = arith.index_cast [[IV]] : index to i32
+  // CHECK-DAG: [[EXTSI_PARAM_1:%.+]] = arith.extsi [[PARAM_1]] : i32 to i64
+  // CHECK-DAG: [[EXTSI_PARAM_2b:%.+]] = arith.extsi [[PARAM_2]] : i32 to i64
+  // CHECK: [[TENSOR_PTR:%.+]] = tt.make_tensor_ptr [[PARAM_0]], {{\[}}[[EXTSI_PARAM_1]], [[EXTSI_PARAM_2b]]], {{\[}}[[EXTSI_PARAM_2a]], [[CST_1_i64]]], {{\[}}[[CST_8_i32]], [[IDX_CAST_1]]] {{.*}} : <tensor<16x32xf16>>
+  // CHECK: [[LOAD:%.+]] = tt.load [[TENSOR_PTR]] : !tt.ptr<tensor<16x32xf16>>
+  // CHECK: [[ADD:%.+]] = arith.addf [[VAR_arg2]], [[LOAD]] : tensor<16x32xf16>
+  // CHECK: scf.yield {{.*}}, [[ADD]] : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
+  // CHECK: }
+  // CHECK: tt.return
+  // CHECK: }
+
+  // COM: Loop containing a tensor descriptor load operation using a loop variant tensor descriptor.
+  tt.func public @load_in_loop2(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %c1_i64 = arith.constant 1 : i64
+    %c8_i32 = arith.constant 8 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
+    %0 = arith.extsi %arg2 : i32 to i64
+    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
+    %tdesc_out, %sum_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc, %sum_iter = %cst) -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
+      %cast_i = arith.index_cast %i : index to i32
+      %load1 = tt.descriptor_load %ptr_iter[%c8_i32, %cast_i] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
+      %sum_next = arith.addf %sum_iter, %load1 : tensor<16x32xf16>
+      %tdesc_in_loop = tt.make_tensor_descriptor %arg0, [%arg2, %arg1], [%c1_i64, %0] : <f16>, <tensor<16x32xf16>>
+      %cmp = arith.cmpi eq, %cast_i, %c8_i32 : i32
+      %sel_tdesc = arith.select %cmp, %ptr_iter, %tdesc_in_loop : !tt.tensordesc<tensor<16x32xf16>>
+      scf.yield %sel_tdesc, %sum_next : !tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>
+    }
+    tt.return
+  }
+  // CHECK: tt.func public @load_in_loop2({{.*}}) {
+  // CHECK-NOT: tt.make_tensor_ptr
+  // CHECK-NOT: tt.load
+  // CHECK: tt.make_tensor_descriptor
+  // CHECK: [[FOR_RES:%.+]]:2 = scf.for [[IV:%.+]] = {{.*}} -> (!tt.tensordesc<tensor<16x32xf16>>, tensor<16x32xf16>) {
+  // CHECK: tt.descriptor_load
+  // CHECK: tt.make_tensor_descriptor
+  // CHECK: }
+  // CHECK: tt.return
+  // CHECK: }
+
+  // COM: Loop yields a tensor descriptor used by a tensor descriptor load.
+  tt.func public @load_uses_loop_result(%arg0: !tt.ptr<f16>, %arg1: i32, %arg2: i32) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c10 = arith.constant 10 : index
+    %c1_i64 = arith.constant 1 : i64
+    %c8_i32 = arith.constant 8 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<16x32xf16>
+    %0 = arith.extsi %arg2 : i32 to i64
+    %tdesc = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <f16>, <tensor<16x32xf16>>
+    %tdesc_out = scf.for %i = %c0 to %c10 step %c1 iter_args(%ptr_iter = %tdesc) -> (!tt.tensordesc<tensor<16x32xf16>>) {
+      scf.yield %ptr_iter : !tt.tensordesc<tensor<16x32xf16>>
+    }
+    %cast_c10 = arith.index_cast %c10 : index to i32
+    %load2 = tt.descriptor_load %tdesc_out[%c8_i32, %cast_c10] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
+    tt.return
+  }
+  // CHECK: tt.func public @load_uses_loop_result({{.*}}) {
+  // CHECK-NOT: tt.make_tensor_ptr
+  // CHECK-NOT: tt.load
+  // CHECK: tt.make_tensor_descriptor
+  // CHECK: tt.descriptor_load
+  // CHECK: tt.return
+  // CHECK: }
+}
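As the last two tests above show, the rewrite is conservative: a `tt.descriptor_load` is transformed only when its descriptor operand can be traced directly to a single `tt.make_tensor_descriptor`. A minimal sketch (hypothetical names), modeled on `load_in_loop2`, of a case the pass leaves untouched:

```mlir
// The descriptor may originate from either of two tt.make_tensor_descriptor
// ops, so no single block pointer can be materialized; the IR is left as-is.
%sel = arith.select %cond, %desc_a, %desc_b : !tt.tensordesc<tensor<16x32xf16>>
%v = tt.descriptor_load %sel[%c0_i32, %c0_i32] : !tt.tensordesc<tensor<16x32xf16>> -> tensor<16x32xf16>
```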

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
@@ -11,6 +11,21 @@
 
 include "mlir/Pass/PassBase.td"
 
+def TritonIntelTensorDescToBlockPointer
+    : Pass<"triton-intel-tdesc-to-block-pointer", "mlir::ModuleOp"> {
+  let summary = "Convert tensor descriptors into block pointers";
+
+  let description = [{
+    This pass attempts to convert tensor descriptors into block pointers.
+  }];
+
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::scf::SCFDialect",
+    "mlir::triton::TritonDialect"
+  ];
+}
+
 def TritonIntelRemoveMasks
     : Pass<"triton-intel-remove-masks", "mlir::ModuleOp"> {
   let summary = "Remove masks from tt.load and tt.store operations";

third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 add_triton_library(TritonIntelTransforms
   RemoveMasks.cpp
+  TensorDescToBlockPointer.cpp
 
   DEPENDS
   TritonIntelTransformsIncGen
