File tree Expand file tree Collapse file tree 2 files changed +37
-0
lines changed
lib/Dialect/Triton/Transforms Expand file tree Collapse file tree 2 files changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -187,6 +187,29 @@ class CombineBroadcastMulReducePattern : public RewritePattern {
187187 }
188188};
189189
190+ // When reducing a 1D tensor the order of elements of the tensor doesn't matter.
191+ // Therefore we can relax the reshape to allow it to re-order elements.
192+ class CombineReshapeReducePatterns : public mlir ::OpRewritePattern<ReshapeOp> {
193+ public:
194+ using OpRewritePattern::OpRewritePattern;
195+
196+ mlir::LogicalResult
197+ matchAndRewrite (triton::ReshapeOp reshapeOp,
198+ mlir::PatternRewriter &rewriter) const override {
199+ if (reshapeOp.getAllowReorder ())
200+ return failure ();
201+ if (reshapeOp.getType ().getRank () != 1 )
202+ return failure ();
203+ for (Operation *user : reshapeOp->getUsers ()) {
204+ if (!isa<triton::ReduceOp, triton::HistogramOp>(user))
205+ return failure ();
206+ }
207+ rewriter.modifyOpInPlace (reshapeOp,
208+ [&]() { reshapeOp.setAllowReorder (true ); });
209+ return success ();
210+ }
211+ };
212+
190213class CombineOpsPass : public TritonCombineOpsBase <CombineOpsPass> {
191214public:
192215 void runOnOperation () override {
@@ -203,6 +226,7 @@ class CombineOpsPass : public TritonCombineOpsBase<CombineOpsPass> {
203226 patterns.add <CombineSelectMaskedLoadPattern>(context);
204227 patterns.add <CombineAddPtrPattern>(context);
205228 patterns.add <CombineBroadcastMulReducePattern>(context);
229+ patterns.add <CombineReshapeReducePatterns>(context);
206230
207231 if (applyPatternsGreedily (m, std::move (patterns)).failed ())
208232 signalPassFailure ();
Original file line number Diff line number Diff line change @@ -345,3 +345,16 @@ tt.func @test_nested_transpose(%arg0: tensor<2x4x8xf32>) -> (tensor<8x2x4xf32>)
345345 // CHECK: tt.return %[[res]]
346346 tt.return %b : tensor <8 x2 x4 xf32 >
347347}
348+
349+ // CHECK-LABEL: test_reshape_reduce
350+ tt.func @test_reshape_reduce (%0: tensor <32 x4 x2 xi32 >) -> (i32 , tensor <16 xi32 >) {
351+ // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<32x4x2xi32> -> tensor<256xi32>
352+ %1 = tt.reshape %0 : tensor <32 x4 x2 xi32 > -> tensor <256 xi32 >
353+ %2 = " tt.reduce" (%1 ) ({
354+ ^bb0 (%arg7: i32 , %arg8: i32 ):
355+ %add = arith.addi %arg7 , %arg8 : i32
356+ tt.reduce.return %add : i32
357+ }) {axis = 0 : i32 } : (tensor <256 xi32 >) -> i32
358+ %3 = tt.histogram %1 : tensor <256 xi32 > -> tensor <16 xi32 >
359+ tt.return %2 , %3 : i32 , tensor <16 xi32 >
360+ }
You can’t perform that action at this time.
0 commit comments