From 5f64bb72043831ca748c871238c4abc85c12a2b0 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 30 Oct 2024 12:32:27 +0000 Subject: [PATCH 1/4] [Triton] Add canonicalization patterns for `tt.reduce` Add canonicalization pattern replacing `tt.reduce` with axis of size 1 with a `tt.reshape` operation. This may leverage further optimizations and simplifications. Signed-off-by: victor-eds --- include/triton/Dialect/Triton/IR/TritonOps.td | 1 + lib/Dialect/Triton/IR/Ops.cpp | 51 +++++++++++++++++++ test/Triton/canonicalize.mlir | 17 +++++++ 3 files changed, 69 insertions(+) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index d3bb95ca95..55a8c1ea5f 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -723,6 +723,7 @@ def TT_ReduceOp: TT_Op<"reduce", ]; let hasVerifier = 1; let hasRegionVerifier = 1; + let hasCanonicalizer = 1; let extraClassDeclaration = [{ llvm::SmallVector getInputTypes(); llvm::SmallVector getElementTypes(); diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index c2c057f42c..e813325a96 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -495,6 +495,57 @@ LogicalResult ReduceOp::verifyRegions() { return verifyRegionsImpl(*this); } +namespace { +/// Replace reduction operations with equivalent reshape operations. +/// +/// This pattern replaces reductions whose input tensor size is 1 in the +/// reduction dimension: +/// ```mlir +/// "tt.reduce"(%0, ...) <{axis = N}> ({...}) +/// : (tensor, ...) -> tensor +/// ``` +/// With equivalent reshape operations (one per operand): +/// ```mlir +/// tt.reshape %0 allow_reorder +/// : tensor -> tensor +/// "tt.reduce"(%0) <{axis = N}> ({...}) +/// {...} +/// : (tensor, ...) -> tensor +/// ``` +struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReduceOp reduceOp, + PatternRewriter &rewriter) const final { + Type resultType = reduceOp->getResultTypes().front(); + // `tensor->Ty` case. `tt.reshape` does not support scalar result + // types, so we simply skip this case. + if (!isa(resultType)) + return failure(); + RankedTensorType inputType = reduceOp.getInputTypes().front(); + int32_t axis = reduceOp.getAxis(); + if (inputType.getShape()[axis] != 1) + return failure(); + SmallVector reshapes(reduceOp.getNumOperands()); + llvm::transform( + llvm::zip_equal(reduceOp.getSrcs(), reduceOp->getResultTypes()), + reshapes.begin(), + [loc = reduceOp.getLoc(), &rewriter](auto pair) -> Value { + auto &[value, resultType] = pair; + return rewriter.create(loc, resultType, value, + /*allow_reorder=*/true); + }); + rewriter.replaceOp(reduceOp, reshapes); + return success(); + } +}; +} // namespace + +void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + llvm::SmallVector ReduceOp::getInputTypes() { return getInputTypesImpl(this->getOperands()); } diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index 8888271e3c..fea19f2b7b 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -50,3 +50,20 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ tt.return %b : tensor<32x1xf32, #blocked0> } } // end module + +// ----- + +// CHECK-LABEL: tt.func @reduce( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>, +// CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16> +tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) { + // CHECK: tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32> + // CHECK: tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16> + %0:2 = "tt.reduce"(%arg0, %arg1) <{axis=1 : i32}> ({ + ^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16): + %1 = arith.addf %acc0, %curr0 : f32 + %2 = arith.mulf %acc1, %curr1 : f16 + tt.reduce.return %1, %2 : f32, f16 + }) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) + tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16> +} From cc3f0f35a214b4688c76e5c9f135971aa2cf32bf Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 30 Oct 2024 12:39:20 +0000 Subject: [PATCH 2/4] Clean comment --- lib/Dialect/Triton/IR/Ops.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index e813325a96..54f966a475 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -508,9 +508,6 @@ namespace { /// ```mlir /// tt.reshape %0 allow_reorder /// : tensor -> tensor -/// "tt.reduce"(%0) <{axis = N}> ({...}) -/// {...} -/// : (tensor, ...) -> tensor /// ``` struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 45f1dbb9d38fa03d4df68791b65ef2c459888afe Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 30 Oct 2024 12:40:58 +0000 Subject: [PATCH 3/4] Explain allow_reorder --- lib/Dialect/Triton/IR/Ops.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 54f966a475..fcdfd0141c 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -529,6 +529,7 @@ struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern { reshapes.begin(), [loc = reduceOp.getLoc(), &rewriter](auto pair) -> Value { auto &[value, resultType] = pair; + // Set allow_reorder to support different tensor layouts. return rewriter.create(loc, resultType, value, /*allow_reorder=*/true); }); From 6aada16ac72e8fada8eaf6617c5a65b45809ddfc Mon Sep 17 00:00:00 2001 From: victor-eds Date: Wed, 30 Oct 2024 13:54:13 +0000 Subject: [PATCH 4/4] Address review comments --- lib/Dialect/Triton/IR/Ops.cpp | 6 ++++-- test/Triton/canonicalize.mlir | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index fcdfd0141c..74ed007347 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -502,12 +502,14 @@ namespace { /// reduction dimension: /// ```mlir /// "tt.reduce"(%0, ...) <{axis = N}> ({...}) -/// : (tensor, ...) -> tensor +/// : (tensor, ...) -> +/// (tensor, ...) /// ``` /// With equivalent reshape operations (one per operand): /// ```mlir /// tt.reshape %0 allow_reorder -/// : tensor -> tensor +/// : tensor -> +/// tensor /// ``` struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir index fea19f2b7b..b8fa2397c1 100644 --- a/test/Triton/canonicalize.mlir +++ b/test/Triton/canonicalize.mlir @@ -57,13 +57,14 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){ // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x16xf32>, // CHECK-SAME: %[[ARG1:.*]]: tensor<2x1x16xf16> tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) { - // CHECK: tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32> - // CHECK: tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16> + // CHECK: %[[VAL0:.*]] = tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32> + // CHECK: %[[VAL1:.*]] = tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16> %0:2 = "tt.reduce"(%arg0, %arg1) <{axis=1 : i32}> ({ ^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16): %1 = arith.addf %acc0, %curr0 : f32 %2 = arith.mulf %acc1, %curr1 : f16 tt.reduce.return %1, %2 : f32, f16 }) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) + // CHECK: tt.return %[[VAL0]], %[[VAL1]] : tensor<2x16xf32>, tensor<2x16xf16> tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16> }