Skip to content

Commit e71689d

Browse files
authored
Properly handle while loop args in OptimizeDescriptorEncoding (#7297)
Fixes the handling of while loop args so that the condition op / yield op in the before body and the arguments of the after body are both handled properly. I added a somewhat contrived lit test that reproduces the problem — the second gather plus the local load is necessary to force a specific layout propagation through the while loop.
1 parent 819cb78 commit e71689d

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,10 +1596,10 @@ SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
15961596
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
15971597
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
15981598
auto result = whileOp.getResults()[resultIdx];
1599-
auto yieldVal =
1600-
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
1599+
auto yieldVal = whileOp.getConditionOp().getArgs()[resultIdx];
16011600
auto initVal = whileOp.getOperands()[resultIdx];
1602-
return {iterArg, result, iterArg, initVal};
1601+
auto bodyArg = whileOp.getAfterArguments()[resultIdx];
1602+
return {iterArg, result, yieldVal, initVal, bodyArg};
16031603
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
16041604
SmallVector<Value> values;
16051605
for (auto &block : ifOp.getThenRegion().getBlocks()) {

test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,43 @@ tt.func public @descriptor_kernel_arg(%arg0: !tt.tensordesc<tensor<64x64xf16>>,
8383
tt.return
8484
}
8585
}
86+
87+
// -----
88+
89+
90+
#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
91+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
92+
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
93+
#smem = #ttg.shared_memory
94+
95+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100"} {
96+
// CHECK-DAG: #[[BLOCKED:.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
97+
// CHECK-DAG: #[[NVMMA_32:.*]] = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8}>
98+
tt.func public @tma_load_while(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: tensor<32xi32, #blocked>, %cond: i1) {
99+
%c1_i32 = arith.constant 1 : i32
100+
%c8_i32 = arith.constant 8 : i32
101+
%c1_i64 = arith.constant 1 : i64
102+
103+
%0 = arith.extsi %arg2 : i32 to i64
104+
// CHECK: tt.make_tensor_descriptor {{.*}} : <i8>, <tensor<1x32xi8, #[[NVMMA_32]]>>
105+
%1 = tt.make_tensor_descriptor %arg0, [%arg1, %arg2], [%0, %c1_i64] : <i8>, <tensor<1x32xi8>>
106+
107+
%2 = scf.while (%arg4 = %1) : (!tt.tensordesc<tensor<1x32xi8>>) -> (!tt.tensordesc<tensor<1x32xi8>>) {
108+
scf.condition(%cond) %arg4 : !tt.tensordesc<tensor<1x32xi8>>
109+
} do {
110+
^bb0(%arg4: !tt.tensordesc<tensor<1x32xi8>>):
111+
// CHECK: ^bb0(%[[ARG4:.*]]: !tt.tensordesc<tensor<1x32xi8, #[[NVMMA_32]]>>):
112+
// CHECK: tt.descriptor_gather %[[ARG4]][{{.*}}] : (!tt.tensordesc<tensor<1x32xi8, #[[NVMMA_32]]>>
113+
%3 = tt.descriptor_gather %arg4[%arg3, %c8_i32] : (!tt.tensordesc<tensor<1x32xi8>>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1>
114+
115+
scf.yield %arg4 : !tt.tensordesc<tensor<1x32xi8>>
116+
}
117+
118+
// CHECK: %[[GATHER:.*]] = tt.descriptor_gather {{.*}} : (!tt.tensordesc<tensor<1x32xi8, #[[NVMMA_32]]>>
119+
%4 = tt.descriptor_gather %1[%arg3, %c8_i32] : (!tt.tensordesc<tensor<1x32xi8>>, tensor<32xi32, #blocked>, i32) -> tensor<32x32xi8, #blocked1>
120+
// CHECK: ttg.local_alloc %[[GATHER]] {{.*}} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #[[NVMMA_32]], #smem>
121+
%8 = ttg.local_alloc %4 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (tensor<32x32xi8, #blocked1>) -> !ttg.memdesc<32x32xi8, #shared, #smem>
122+
123+
tt.return
124+
}
125+
}

0 commit comments

Comments (0)