Skip to content

Commit 6f41a1d

Browse files
authored
Add a new optimization to version loops containing loads using block ptrs with unknown strides (#5532)
This PR versions `scf.for` loops containing tt.load on block pointers, where the block ptr is declared by a make_tensor_ptr operation with no strides equal to 1 (strides unknown at compile time). The versioned loop will then contain tt.load operations that can be marked as "row-major/column-major" and eventually lowered to efficient 2D block IO loads. --------- Signed-off-by: Ettore Tiotto <[email protected]>
1 parent d3a28a6 commit 6f41a1d

File tree

8 files changed

+437
-0
lines changed

8 files changed

+437
-0
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9797
mlir::triton::intel::registerTritonIntelFuseReshape();
9898
mlir::triton::intel::registerTritonIntelRemoveBoundaryChecks();
9999
mlir::triton::intel::registerTritonIntelRemoveMasks();
100+
mlir::triton::intel::registerTritonIntelStrideVersioning();
100101
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
101102
mlir::triton::registerRelayoutTritonGPUPass();
102103
mlir::triton::gpu::registerAllocateSharedMemoryPass();
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// RUN: triton-opt %s -split-input-file -triton-intel-stride-versioning | FileCheck %s
2+
3+
module {
4+
tt.func public @version_for_loop(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: i64, %arg3: i64, %arg4: i64 {tt.divisibility = 16 : i32}, %arg5: i64) {
5+
%c64_i32 = arith.constant 64 : i32
6+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
7+
%c256_i32 = arith.constant 256 : i32
8+
%c32_i32 = arith.constant 32 : i32
9+
%c0_i32 = arith.constant 0 : i32
10+
%c4096_i32 = arith.constant 4096 : i32
11+
%c8192_i32 = arith.constant 8192 : i32
12+
%c4_i32 = arith.constant 4 : i32
13+
%0 = tt.get_program_id x : i32
14+
%1 = arith.divsi %0, %c64_i32 : i32
15+
%2 = arith.muli %1, %c4_i32 : i32
16+
%3 = arith.subi %c32_i32, %2 : i32
17+
%4 = arith.minsi %3, %c4_i32 : i32
18+
%5 = arith.remsi %0, %c64_i32 : i32
19+
%6 = arith.remsi %5, %4 : i32
20+
%7 = arith.addi %2, %6 : i32
21+
%8 = arith.divsi %5, %4 : i32
22+
%9 = arith.extsi %c8192_i32 : i32 to i64
23+
%10 = arith.extsi %c4096_i32 : i32 to i64
24+
%11 = tt.make_tensor_ptr %arg0, [%9, %10], [%arg2, %arg3], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16>>
25+
%12 = tt.make_tensor_ptr %arg1, [%10, %10], [%arg4, %arg5], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
26+
%13 = arith.muli %7, %c256_i32 : i32
27+
%14 = arith.muli %8, %c256_i32 : i32
28+
%15:2 = scf.for %arg9 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
29+
%20 = tt.advance %11, [%13, %arg11] : <tensor<256x32xbf16>>
30+
%21 = tt.load %20 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
31+
%22 = tt.advance %12, [%arg11, %14] : <tensor<32x256xbf16>>
32+
%23 = tt.load %22 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
33+
%24 = tt.dot %21, %23, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
34+
%25 = arith.addf %arg10, %24 : tensor<256x256xf32>
35+
%26 = arith.addi %arg11, %c32_i32 : i32
36+
scf.yield %25, %26 : tensor<256x256xf32>, i32
37+
}
38+
tt.return
39+
}
40+
41+
// CHECK: tt.func public @version_for_loop
42+
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
43+
// CHECK-DAG: [[NEW_PTR1:%.+]] = tt.make_tensor_ptr %arg0, {{.*}}, [%arg2, %c1_i64], {{.*}} {order = array<i32: 1, 0>} : <tensor<256x32xbf16>>
44+
// CHECK-DAG: [[ORIG_PTR1:%.+]] = tt.make_tensor_ptr %arg0, {{.*}}, [%arg2, %arg3], {{.*}} {order = array<i32: 1, 0>} : <tensor<256x32xbf16>>
45+
// CHECK: [[NEW_PTR2:%.+]] = tt.make_tensor_ptr %arg1, {{.*}}, [%arg4, %c1_i64], {{.*}} {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
46+
// CHECK: [[ORIG_PTR2:%.+]] = tt.make_tensor_ptr %arg1, {{.*}}, [%arg4, %arg5], {{.*}} {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
47+
// CHECK-DAG: [[CMP1:%.+]] = arith.cmpi eq, %arg3, [[CST_1_i64]] : i64
48+
// CHECK-DAG: [[CMP2:%.+]] = arith.cmpi eq, %arg5, [[CST_1_i64]] : i64
49+
// CHECK: [[VER_COND:%.+]] = arith.andi [[CMP1]], [[CMP2]] : i1
50+
// CHECK: [[LOOP_VER:%.+]]:2 = scf.if [[VER_COND]]
51+
// CHECK: scf.for
52+
// CHECK: tt.advance [[NEW_PTR1]]
53+
// CHECK: tt.advance [[NEW_PTR2]]
54+
// CHECK: } else {
55+
// CHECK: scf.for
56+
// CHECK: tt.advance [[ORIG_PTR1]]
57+
// CHECK: tt.advance [[ORIG_PTR2]]
58+
// CHECK: }
59+
}
60+
61+
// -----
62+
63+
module {
64+
tt.func public @do_not_version(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg2: i64, %arg3: i64, %arg4: i64 {tt.divisibility = 16 : i32}, %arg5: i64) {
65+
%c64_i32 = arith.constant 64 : i32
66+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
67+
%c256_i32 = arith.constant 256 : i32
68+
%c32_i32 = arith.constant 32 : i32
69+
%c0_i32 = arith.constant 0 : i32
70+
%c4096_i32 = arith.constant 4096 : i32
71+
%c8192_i32 = arith.constant 8192 : i32
72+
%c4_i32 = arith.constant 4 : i32
73+
%c2_i64 = arith.constant 2 : i64
74+
%c4_i64 = arith.constant 4 : i64
75+
%0 = tt.get_program_id x : i32
76+
%1 = arith.divsi %0, %c64_i32 : i32
77+
%2 = arith.muli %1, %c4_i32 : i32
78+
%3 = arith.subi %c32_i32, %2 : i32
79+
%4 = arith.minsi %3, %c4_i32 : i32
80+
%5 = arith.remsi %0, %c64_i32 : i32
81+
%6 = arith.remsi %5, %4 : i32
82+
%7 = arith.addi %2, %6 : i32
83+
%8 = arith.divsi %5, %4 : i32
84+
%9 = arith.extsi %c8192_i32 : i32 to i64
85+
%10 = arith.extsi %c4096_i32 : i32 to i64
86+
%11 = tt.make_tensor_ptr %arg0, [%9, %10], [%c4_i64, %c2_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16>>
87+
%12 = tt.make_tensor_ptr %arg1, [%10, %10], [%c2_i64, %c4_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
88+
%13 = arith.muli %7, %c256_i32 : i32
89+
%14 = arith.muli %8, %c256_i32 : i32
90+
%15:2 = scf.for %arg9 = %c0_i32 to %c4096_i32 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
91+
%20 = tt.advance %11, [%13, %arg11] : <tensor<256x32xbf16>>
92+
%21 = tt.load %20 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
93+
%22 = tt.advance %12, [%arg11, %14] : <tensor<32x256xbf16>>
94+
%23 = tt.load %22 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
95+
%24 = tt.dot %21, %23, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
96+
%25 = arith.addf %arg10, %24 : tensor<256x256xf32>
97+
%26 = arith.addi %arg11, %c32_i32 : i32
98+
scf.yield %25, %26 : tensor<256x256xf32>, i32
99+
}
100+
tt.return
101+
}
102+
103+
// CHECK: tt.func public @do_not_version
104+
// CHECK-DAG: [[PTR1:%.+]] = tt.make_tensor_ptr %arg0
105+
// CHECK-DAG: [[PTR2:%.+]] = tt.make_tensor_ptr %arg1
106+
// CHECK-NOT: scf.if
107+
// CHECK: scf.for
108+
// CHECK: tt.advance [[PTR1]]
109+
// CHECK: tt.advance [[PTR2]]
110+
}

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def make_ttir(cls, mod, metadata, opt):
213213
passes.common.add_licm(pm)
214214
intel.passes.ttir.add_remove_boundary_checks(pm)
215215
intel.passes.ttir.add_remove_masks(pm)
216+
intel.passes.ttir.add_stride_versioning(pm)
216217
intel.passes.ttir.add_fuse_reshape(pm)
217218
passes.common.add_canonicalizer(pm)
218219
passes.ttir.add_combine(pm)

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,40 @@ def TritonIntelRemoveBoundaryChecks
107107
];
108108
}
109109

110+
def TritonIntelStrideVersioning
111+
: Pass<"triton-intel-stride-versioning", "mlir::ModuleOp"> {
112+
let summary = "Version loops containing block pointer loads if none of the strides is 1";
113+
114+
let description = [{
115+
This pass versions loops that contain tt.load on a block pointer, if none of the block pointer strides is 1.
116+
For example, given:
117+
118+
%cst = arith.constant ...
119+
%ptr = tt.make_tensor_ptr %base, [%s0, %s1], [%cst, %b], [%x, %y] : <tensor<512x64xf16>>
120+
scf.for ... {
121+
%load = tt.load %ptr : !tt.ptr<tensor<512x64xf16>>
122+
...
123+
}
124+
125+
The transformation creates:
126+
127+
%ptr = tt.make_tensor_ptr %base, [%s0, %s1], [%a, %b], [%x, %y] : <tensor<512x64xf16>>
128+
%ptr' = tt.make_tensor_ptr %base, [%s0, %s1], [%a, 1], [%x, %y] : <tensor<512x64xf16>>
129+
if (%b == 1)
130+
scf.for ... {
131+
%load = tt.load %ptr' : !tt.ptr<tensor<512x64xf16>>
132+
...
133+
}
134+
else
135+
scf.for ... {
136+
%load = tt.load %ptr : !tt.ptr<tensor<512x64xf16>>
137+
...
138+
}
139+
}];
140+
141+
let dependentDialects = [
142+
"mlir::triton::TritonDialect"
143+
];
144+
}
145+
110146
#endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES

third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_triton_library(TritonIntelTransforms
22
FuseReshape.cpp
33
RemoveBoundaryChecks.cpp
44
RemoveMasks.cpp
5+
StrideVersioning.cpp
56
TensorDescToBlockPointer.cpp
67

78
DEPENDS

0 commit comments

Comments
 (0)