Skip to content

Commit 7efdd03

Browse files
Fix build and test failures from 717997b
Signed-off-by: Whitney Tsang <[email protected]>
1 parent dbccd22 commit 7efdd03

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

test/TritonIntelGPU/coalesce.mlir

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
345345
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
346346
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.threads-per-warp" = 32 : i32} {
347347
// CHECK-DAG: [[BLOCKED_LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
348-
// CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 32, 1], warpsPerCTA = [1, 1, 16], order = [0, 1, 2]}>
349-
// CHECK-DAG: [[BLOCKED_LAYOUT2:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
348+
// CHECK-DAG: [[BLOCKED_LAYOUT1:#.*]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 1, 32], warpsPerCTA = [1, 4, 4], order = [2, 1, 0]}>
350349
// CHECK: @triton_red_fused_mul_sum_0
351350
tt.func public @triton_red_fused_mul_sum_0(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
352351
%c128_i32 = arith.constant 128 : i32
@@ -368,7 +367,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
368367
// CHECK: [[RES:%.*]]:2 = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]], [[ARG2:%.*]] = {{.*}}) -> (!tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>)
369368
%8:2 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args(%arg6 = %6, %arg8 = %cst_0) -> (!tt.ptr<tensor<1x32x128xf32, #blocked1>>, tensor<32x128xf32, #blocked>) : i32 {
370369
// CHECK: [[LOAD:%.*]] = tt.load [[ARG1]] evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
371-
// CHECK-NEXT: ttg.convert_layout [[LOAD]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
372370
%17 = tt.load %arg6 evictionPolicy = evict_last {boundaryCheck = array<i32: 2>, padding = 1 : i32} : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
373371
// CHECK: scf.yield [[ARG1]], [[ARG2]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>, tensor<32x128xf32, [[BLOCKED_LAYOUT]]>
374372
scf.yield %arg6, %arg8 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>, tensor<32x128xf32, #blocked>
@@ -404,7 +402,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
404402
scf.yield %arg7 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
405403
}
406404
// CHECK: [[LOAD_RES:%.*]] = tt.load [[RES]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
407-
// CHECK: ttg.convert_layout [[LOAD_RES]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
408405
%res = tt.load %8#0 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
409406
tt.return
410407
}

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ struct CoalescePass
4848
});
4949

5050
const auto &contiguity = axisInfoAnalysis.getAxisInfo(ptr)->getContiguity();
51-
SmallVector<unsigned> order = argSort(contiguity);
51+
SmallVector<unsigned> order = getOrderFromContiguity(contiguity);
5252
LLVM_DEBUG(llvm::dbgs().indent(2)
5353
<< "order=[" << tt::join(order, ", ") << "]\n";);
5454

@@ -67,8 +67,8 @@ struct CoalescePass
6767
Value val = getMemAccessPtr(use);
6868
if (!val || !matchesShape(val) || memAccessesSameOrder.contains(use))
6969
continue;
70-
auto currOrder =
71-
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
70+
auto currOrder = getOrderFromContiguity(
71+
axisInfoAnalysis.getAxisInfo(val)->getContiguity());
7272
if (order == currOrder) {
7373
LLVM_DEBUG(llvm::dbgs().indent(2)
7474
<< "multi-root-slice: insert to memAccessesSameOrder "

0 commit comments

Comments
 (0)