Skip to content

Commit c2c193a

Browse files
authored
Optimize reduce(reshape_1D) (#5748)
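When reducing (or taking a histogram of) a 1D tensor, the order of its elements doesn't matter. This allows us to relax the reshape that produces the tensor by setting `allow_reorder`, so it is free to re-order elements.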
1 parent c048fcb commit c2c193a

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,29 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
187187
}
188188
};
189189

190+
// When reducing a 1D tensor the order of elements of the tensor doesn't matter.
191+
// Therefore we can relax the reshape to allow it to re-order elements.
192+
class CombineReshapeReducePatterns : public mlir::OpRewritePattern<ReshapeOp> {
193+
public:
194+
using OpRewritePattern::OpRewritePattern;
195+
196+
mlir::LogicalResult
197+
matchAndRewrite(triton::ReshapeOp reshapeOp,
198+
mlir::PatternRewriter &rewriter) const override {
199+
if (reshapeOp.getAllowReorder())
200+
return failure();
201+
if (reshapeOp.getType().getRank() != 1)
202+
return failure();
203+
for (Operation *user : reshapeOp->getUsers()) {
204+
if (!isa<triton::ReduceOp, triton::HistogramOp>(user))
205+
return failure();
206+
}
207+
rewriter.modifyOpInPlace(reshapeOp,
208+
[&]() { reshapeOp.setAllowReorder(true); });
209+
return success();
210+
}
211+
};
212+
190213
class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
191214
public:
192215
void runOnOperation() override {
@@ -203,6 +226,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
203226
patterns.add<CombineSelectMaskedLoadPattern>(context);
204227
patterns.add<CombineAddPtrPattern>(context);
205228
patterns.add<CombineBroadcastMulReducePattern>(context);
229+
patterns.add<CombineReshapeReducePatterns>(context);
206230

207231
if (applyPatternsGreedily(m, std::move(patterns)).failed())
208232
signalPassFailure();

test/Triton/combine.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,16 @@ tt.func @test_nested_transpose(%arg0: tensor<2x4x8xf32>) -> (tensor<8x2x4xf32>)
345345
// CHECK: tt.return %[[res]]
346346
tt.return %b : tensor<8x2x4xf32>
347347
}
348+
349+
// A 1D reshape whose only users are tt.reduce and tt.histogram is expected to
// come out of the pass with the allow_reorder attribute set.
// CHECK-LABEL: test_reshape_reduce
tt.func @test_reshape_reduce(%src: tensor<32x4x2xi32>) -> (i32, tensor<16xi32>) {
  // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<32x4x2xi32> -> tensor<256xi32>
  %flat = tt.reshape %src : tensor<32x4x2xi32> -> tensor<256xi32>
  // First user: a sum reduction along the only axis.
  %sum = "tt.reduce" (%flat) ({
  ^bb0(%lhs: i32, %rhs: i32):
    %acc = arith.addi %lhs, %rhs : i32
    tt.reduce.return %acc : i32
  }) {axis = 0 : i32} : (tensor<256xi32>) -> i32
  // Second user: a histogram over the same flattened tensor.
  %hist = tt.histogram %flat : tensor<256xi32> -> tensor<16xi32>
  tt.return %sum, %hist : i32, tensor<16xi32>
}

0 commit comments

Comments
 (0)