Commit 19fe7cb

Revert "Improve thread locality for reduction ops (#5671)" (#5709)
Reverting due to regressions in internal tests
1 parent 216385e commit 19fe7cb

5 files changed (+3, -46 lines)

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -700,7 +700,7 @@ LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
 }
 
 OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
-  if (getType() == getSrc().getType() && !getAllowReorder()) {
+  if (getType() == getSrc().getType()) {
     // no-op
     return getSrc();
   }
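
With the `!getAllowReorder()` condition dropped, a reshape whose result type matches its source type folds away again even when `allow_reorder` is set. A minimal sketch in Triton IR of the case this affects (the function name below is illustrative, not from the commit):

// Hypothetical input: a same-type reshape marked allow_reorder.
tt.func @fold_same_type_reshape(%arg0: tensor<8xf32>) -> tensor<8xf32> {
  // After this revert, ReshapeOp::fold treats this as a no-op and returns the
  // source value, so canonicalization replaces uses of %0 with %arg0.
  %0 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
  tt.return %0 : tensor<8xf32>
}

This is the same case the updated check in test/Triton/combine.mlir below exercises.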

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ struct CanonicalizeConvertFromReshape
 
     if (isExpensiveView(convert.getSrc().getType(), op.getType()))
      return failure();
-    if (!op.getAllowReorder())
+    if (!op.getAllowReorder() || op.getEfficientLayout())
      return failure();
 
    rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
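
Restoring the `op.getEfficientLayout()` guard means CanonicalizeConvertFromReshape again bails out on reshapes already tagged `efficient_layout`, and only folds a `ttg.convert_layout` into a plain `allow_reorder` reshape. A rough sketch of the kind of IR the pattern can still rewrite, with illustrative layouts and names adapted from the test deleted further down (not taken from the commit itself):

// Assumed layouts, for illustration only.
#blocked_src = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked_tmp = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1]}>
#blocked_dst = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}>
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
  tt.func @convert_then_reshape(%arg0: tensor<64x64xf32, #blocked_src>) -> tensor<4096xf32, #blocked_dst> {
    // Assuming the view is not considered expensive, the pattern may fold the
    // convert into the reshape, since allow_reorder permits reordering anyway.
    // Had the reshape carried efficient_layout, the restored guard would make
    // the pattern bail out and keep the convert_layout.
    %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked_src> -> tensor<64x64xf32, #blocked_tmp>
    %r = tt.reshape %c allow_reorder : tensor<64x64xf32, #blocked_tmp> -> tensor<4096xf32, #blocked_dst>
    tt.return %r : tensor<4096xf32, #blocked_dst>
  }
}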

test/Triton/combine.mlir

Lines changed: 1 addition & 1 deletion

@@ -290,7 +290,7 @@ tt.func @test_canonicalize_view(%arg0: tensor<8xf32>, %arg1: tensor<f32>) -> (te
     // CHECK: %{{.*}} = tt.splat %arg1 : tensor<f32> -> tensor<2x2x2xf32>
     %view2 = tt.reshape %splat allow_reorder : tensor<8xf32> -> tensor<2x2x2xf32>
 
-    %view3 = tt.reshape %arg0 : tensor<8xf32> -> tensor<8xf32>
+    %view3 = tt.reshape %arg0 allow_reorder : tensor<8xf32> -> tensor<8xf32>
     // CHECK: %{{.*}} = arith.addf %arg0, %arg0 : tensor<8xf32>
     %add = arith.addf %view3, %arg0 : tensor<8xf32>

test/TritonGPU/canonicalize.mlir

Lines changed: 0 additions & 21 deletions

@@ -40,27 +40,6 @@ tt.func @test_canonicalize_convert_expensive_view(%arg0: tensor<256x16xf32, #blo
 
 // -----
 
-// test that the convert does get combined with the view even if the resulting operation
-// is an efficient view.
-// CHECK-LABEL: @test_canonicalize_convert_view
-// CHECK-SAME: (%[[ARG:.+]]: tensor<64x64xf32
-// CHECK-NOT: ttg.convert_layout
-// CHECK: %[[V:.+]] = tt.reshape %[[ARG]] allow_reorder
-// CHECK: tt.return %[[V]]
-#blocked0 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
-#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [8, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
-
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
-tt.func @test_canonicalize_convert_view(%arg0: tensor<64x64xf32, #blocked0>) -> tensor<4096xf32, #blocked1> {
-    %c = ttg.convert_layout %arg0 : tensor<64x64xf32, #blocked0> -> tensor<64x64xf32, #blocked2>
-    %r = tt.reshape %c allow_reorder efficient_layout : tensor<64x64xf32, #blocked2> -> tensor<4096xf32, #blocked1>
-    tt.return %r : tensor<4096xf32, #blocked1>
-}
-} // end module
-
-// -----
-
 // CHECK-LABEL: @test_canonicalize_convert_histogram
 // CHECK-SAME: (%[[ARG:.+]]: tensor<256xi32
 // CHECK-NOT: ttg.convert_layout

test/TritonGPU/optimize-locality.mlir

Lines changed: 0 additions & 22 deletions

@@ -596,28 +596,6 @@ module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-
   }
 }
 
-// -----
-
-
-// CHECK-DAG: #[[$BLOCK0:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}>
-// CHECK-DAG: #[[$BLOCK1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [2, 1], order = [0, 1]}>
-// CHECK-LABEL: optimize_view_layout_same_shape
-// CHECK: %[[R:.+]] = tt.reshape {{.*}} allow_reorder efficient_layout : tensor<64x16xf32, #[[$BLOCK0]]> -> tensor<64x16xf32, #[[$BLOCK1]]>
-// CHECK: %[[C:.+]] = ttg.convert_layout %[[R]] : tensor<64x16xf32, #[[$BLOCK1]]> -> tensor<64x16xf32, #[[$BLOCK0]]>
-// CHECK: "tt.reduce"(%[[C]])
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}>
-module attributes {"ttg.target" = "cuda:80", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 32 : i32} {
-  tt.func public @optimize_view_layout_same_shape(%arg0: tensor<64x16xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>> {
-    %0 = tt.reshape %arg0 allow_reorder : tensor<64x16xf32, #blocked> -> tensor<64x16xf32, #blocked>
-    %1 = "tt.reduce"(%0) <{axis = 1 : i32}> ({
-    ^bb0(%arg1: f32, %arg2: f32):
-      %2 = arith.maximumf %arg1, %arg2 : f32
-      tt.reduce.return %2 : f32
-    }) : (tensor<64x16xf32, #blocked>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
-    tt.return %1 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
-  }
-}
-
 // -----
 #blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
