Skip to content

Commit 5f64bb7

Browse files
committed
[Triton] Add canonicalization patterns for tt.reduce
Add canonicalization pattern replacing `tt.reduce` with axis of size 1 with a `tt.reshape` operation. This may leverage further optimizations and simplifications. Signed-off-by: victor-eds <[email protected]>
1 parent efce869 commit 5f64bb7

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,7 @@ def TT_ReduceOp: TT_Op<"reduce",
723723
];
724724
let hasVerifier = 1;
725725
let hasRegionVerifier = 1;
726+
let hasCanonicalizer = 1;
726727
let extraClassDeclaration = [{
727728
llvm::SmallVector<RankedTensorType> getInputTypes();
728729
llvm::SmallVector<Type> getElementTypes();

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,57 @@ LogicalResult ReduceOp::verifyRegions() {
495495
return verifyRegionsImpl<ReduceReturnOp>(*this);
496496
}
497497

498+
namespace {
499+
/// Replace reduction operations with equivalent reshape operations.
500+
///
501+
/// This pattern replaces reductions whose input tensor size is 1 in the
502+
/// reduction dimension:
503+
/// ```mlir
504+
/// "tt.reduce"(%0, ...) <{axis = N}> ({...})
505+
/// : (tensor<S0x...xSN-1x1xSN+1x...>, ...) -> tensor<S0x...xSN-1xSN+1x...>
506+
/// ```
507+
/// With equivalent reshape operations (one per operand):
508+
/// ```mlir
509+
/// tt.reshape %0 allow_reorder
510+
/// : tensor<S0x...xSN-1x1xSN+1x...> -> tensor<S0x...xSN-1xSN+1x...>
511+
/// "tt.reduce"(%0) <{axis = N}> ({...})
512+
/// {...}
513+
/// : (tensor<S0x...xSN-1x1xSN+1x...>, ...) -> tensor<S0x...xSN-1xSN+1x...>
514+
/// ```
515+
struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern<ReduceOp> {
516+
using OpRewritePattern<ReduceOp>::OpRewritePattern;
517+
518+
LogicalResult matchAndRewrite(ReduceOp reduceOp,
519+
PatternRewriter &rewriter) const final {
520+
Type resultType = reduceOp->getResultTypes().front();
521+
// `tensor<NxTy>->Ty` case. `tt.reshape` does not support scalar result
522+
// types, so we simply skip this case.
523+
if (!isa<RankedTensorType>(resultType))
524+
return failure();
525+
RankedTensorType inputType = reduceOp.getInputTypes().front();
526+
int32_t axis = reduceOp.getAxis();
527+
if (inputType.getShape()[axis] != 1)
528+
return failure();
529+
SmallVector<Value> reshapes(reduceOp.getNumOperands());
530+
llvm::transform(
531+
llvm::zip_equal(reduceOp.getSrcs(), reduceOp->getResultTypes()),
532+
reshapes.begin(),
533+
[loc = reduceOp.getLoc(), &rewriter](auto pair) -> Value {
534+
auto &[value, resultType] = pair;
535+
return rewriter.create<ReshapeOp>(loc, resultType, value,
536+
/*allow_reorder=*/true);
537+
});
538+
rewriter.replaceOp(reduceOp, reshapes);
539+
return success();
540+
}
541+
};
542+
} // namespace
543+
544+
void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
545+
MLIRContext *context) {
546+
results.add<CanonicalizeReshapeReduceOpPattern>(context);
547+
}
548+
498549
llvm::SmallVector<RankedTensorType> ReduceOp::getInputTypes() {
499550
return getInputTypesImpl(this->getOperands());
500551
}

test/Triton/canonicalize.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,20 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
5050
tt.return %b : tensor<32x1xf32, #blocked0>
5151
}
5252
} // end module
53+
54+
// -----
55+
56+
// CHECK-LABEL: tt.func @reduce(
57+
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>,
58+
// CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16>
59+
tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) {
60+
// CHECK: tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
61+
// CHECK: tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
62+
%0:2 = "tt.reduce"(%arg0, %arg1) <{axis=1 : i32}> ({
63+
^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16):
64+
%1 = arith.addf %acc0, %curr0 : f32
65+
%2 = arith.mulf %acc1, %curr1 : f16
66+
tt.reduce.return %1, %2 : f32, f16
67+
}) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>)
68+
tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16>
69+
}

0 commit comments

Comments
 (0)