Skip to content

Commit 82d9ce2

Browse files
committed
properly handle while loops
1 parent 13db538 commit 82d9ce2

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

test/TritonIntelGPU/optimize-block-io-encoding.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect --tritonintelgpu-optimize-block-io-encoding | FileCheck %s
22

3+
// COM: test complete example
34
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>
45
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
56
#blocked2 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 16], warpsPerCTA = [16, 2], order = [1, 0]}>
@@ -59,6 +60,57 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.th
5960

6061
// -----
6162

63+
// COM: Test while loop / tt.advance before tt.load (TODO)
64+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
65+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
66+
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 4], warpsPerCTA = [32, 1], order = [1, 0]}>
67+
// CHECK-DAG: #[[$SUBGROUP_2D_BLOCK:.+]] = #ttig.subgroup_2d_block<{warpsPerCTA = [8, 4], instrShape = [8, 16], numBlocks=2, order=[1, 0], kWidth=1, threadsPerWarp=16}>
68+
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
69+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.min_sg_size = 16 : i32, ttig.support_dpas, ttig.support_sg_2d_block} {
70+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>) {
71+
%c1024_i64 = arith.constant 1024 : i64
72+
%c5120_i64 = arith.constant 5120 : i64
73+
%c1_i64 = arith.constant 1 : i64
74+
%c256_i32 = arith.constant 256 : i32
75+
%c0_i32 = arith.constant 0 : i32
76+
%c32_i32 = arith.constant 32 : i32
77+
78+
// CHECK: %[[A_PTR:.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
79+
%a_ptr = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%c256_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #blocked1>>
80+
81+
// CHECK: scf.while {{.*}} : (!tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>) -> !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
82+
%1 = scf.while (%a_ptr_crt = %a_ptr) : (!tt.ptr<tensor<256x32xf16, #blocked1>>) -> (!tt.ptr<tensor<256x32xf16, #blocked1>>) {
83+
%2 = "dummy.evaluate_condition"() : () -> i1
84+
// CHECK: scf.condition({{.*}}) {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
85+
scf.condition(%2) %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>>
86+
} do {
87+
^bb0(%a_ptr_crt: !tt.ptr<tensor<256x32xf16, #blocked1>>):
88+
// CHECK: ^bb0({{.*}}: !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>):
89+
90+
// CHECK: %[[A_LOAD:.*]] = tt.load {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
91+
%3 = tt.load %a_ptr_crt {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #blocked1>>
92+
// CHECK: ttg.convert_layout %[[A_LOAD]] : tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]> -> tensor<256x32xf16, #[[$BLOCKED]]>
93+
// CHECK: ttg.convert_layout {{.*}} : tensor<256x32xf16, #[[$BLOCKED]]> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>
94+
%4 = ttg.convert_layout %3 : tensor<256x32xf16, #blocked1> -> tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
95+
96+
%cstB = arith.constant dense<0.000000e+00> : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
97+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
98+
99+
// CHECK: tt.dot {{.*}} : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[$DPAS]]>
100+
%5 = tt.dot %4, %cstB, %cst, inputPrecision = tf32 : tensor<256x32xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
101+
%6 = ttg.convert_layout %5 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked1>
102+
// COM: TODO: support nested tt.advance
103+
// %3 = tt.advance %a_ptr_crt, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #blocked1>>
104+
105+
// CHECK: scf.yield {{.*}} : !tt.ptr<tensor<256x32xf16, #[[$SUBGROUP_2D_BLOCK]]>>
106+
scf.yield %a_ptr_crt : !tt.ptr<tensor<256x32xf16, #blocked1>>
107+
}
108+
tt.return
109+
}
110+
}
111+
112+
// -----
113+
62114
// COM: test complex control flow
63115
// COM: Note that instead of using tt.advance we make a new tensor ptr each time. This is nice, because it lets us test that we can find MakeTensorPtr op inside the scf.if.
64116
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [8, 4], order = [1, 0]}>

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeBlockIOEncoding.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
2626
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
2727
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
2828
auto result = whileOp.getResults()[resultIdx];
29-
auto yieldVal =
30-
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
29+
auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx];
3130
auto initVal = whileOp.getOperands()[resultIdx];
32-
return {iterArg, result, iterArg, initVal};
31+
auto bodyArg = whileOp.getAfterArguments()[resultIdx];
32+
return {iterArg, result, yieldVal, initVal, bodyArg};
3333
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
3434
SmallVector<Value> values;
3535
for (auto &block : ifOp.getThenRegion().getBlocks()) {
@@ -228,7 +228,10 @@ class TritonIntelGPUOptimizeBlockIOEncodingPass
228228
<< "\nincompatible with Subgroup 2D Block Layout.\n");
229229
return;
230230
}
231+
LLVM_DEBUG(DBGS() << "Retrieving tensor ptr op for ptr " << ptr << "\n");
231232
MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(ptr);
233+
LLVM_DEBUG(DBGS() << "Rerwrite encoding for block ptr op "
234+
<< makeTensorPtrOp << "\n");
232235

233236
auto oldTensorPtrType = cast<PointerType>(makeTensorPtrOp.getType());
234237
auto oldTensorType =

0 commit comments

Comments
 (0)