diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index d3bb95ca95..55a8c1ea5f 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -723,6 +723,7 @@ def TT_ReduceOp: TT_Op<"reduce",
   ];
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
+  let hasCanonicalizer = 1;
   let extraClassDeclaration = [{
     llvm::SmallVector<RankedTensorType> getInputTypes();
     llvm::SmallVector<Type> getElementTypes();
diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp
index c2c057f42c..74ed007347 100644
--- a/lib/Dialect/Triton/IR/Ops.cpp
+++ b/lib/Dialect/Triton/IR/Ops.cpp
@@ -495,6 +495,57 @@ LogicalResult ReduceOp::verifyRegions() {
   return verifyRegionsImpl(*this);
 }
 
+namespace {
+/// Replace reduction operations with equivalent reshape operations.
+///
+/// This pattern replaces reductions whose input tensor size is 1 in the
+/// reduction dimension:
+/// ```mlir
+/// "tt.reduce"(%0, ...) <{axis = N}> ({...})
+///     : (tensor<...x1x...xTy>, ...) ->
+///       (tensor<...x...xTy>, ...)
+/// ```
+/// With equivalent reshape operations (one per operand):
+/// ```mlir
+/// tt.reshape %0 allow_reorder
+///     : tensor<...x1x...xTy> ->
+///       tensor<...x...xTy>
+/// ```
+struct CanonicalizeReshapeReduceOpPattern final : OpRewritePattern<ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp reduceOp,
+                                PatternRewriter &rewriter) const final {
+    Type resultType = reduceOp->getResultTypes().front();
+    // `tensor<1xTy>->Ty` case. `tt.reshape` does not support scalar result
+    // types, so we simply skip this case.
+    if (!isa<RankedTensorType>(resultType))
+      return failure();
+    RankedTensorType inputType = reduceOp.getInputTypes().front();
+    int32_t axis = reduceOp.getAxis();
+    if (inputType.getShape()[axis] != 1)
+      return failure();
+    SmallVector<Value> reshapes(reduceOp.getNumOperands());
+    llvm::transform(
+        llvm::zip_equal(reduceOp.getSrcs(), reduceOp->getResultTypes()),
+        reshapes.begin(),
+        [loc = reduceOp.getLoc(), &rewriter](auto pair) -> Value {
+          auto &[value, resultType] = pair;
+          // Set allow_reorder to support different tensor layouts.
+          return rewriter.create<ReshapeOp>(loc, resultType, value,
+                                            /*allow_reorder=*/true);
+        });
+    rewriter.replaceOp(reduceOp, reshapes);
+    return success();
+  }
+};
+} // namespace
+
+void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<CanonicalizeReshapeReduceOpPattern>(context);
+}
+
 llvm::SmallVector<RankedTensorType> ReduceOp::getInputTypes() {
   return getInputTypesImpl(this->getOperands());
 }
diff --git a/test/Triton/canonicalize.mlir b/test/Triton/canonicalize.mlir
index 8888271e3c..b8fa2397c1 100644
--- a/test/Triton/canonicalize.mlir
+++ b/test/Triton/canonicalize.mlir
@@ -50,3 +50,21 @@ tt.func @fn(%arg0: tensor<1xf32, #sliced0>) -> (tensor<32x1xf32, #blocked0>){
   tt.return %b : tensor<32x1xf32, #blocked0>
 }
 } // end module
+
+// -----
+
+// CHECK-LABEL: tt.func @reduce(
+// CHECK-SAME:    %[[ARG0:.*]]: tensor<2x1x16xf32>,
+// CHECK-SAME:    %[[ARG1:.*]]: tensor<2x1x16xf16>
+tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) {
+  // CHECK: %[[VAL0:.*]] = tt.reshape %[[ARG0]] allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
+  // CHECK: %[[VAL1:.*]] = tt.reshape %[[ARG1]] allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
+  %0:2 = "tt.reduce"(%arg0, %arg1) <{axis = 1 : i32}> ({
+  ^bb0(%acc0: f32, %acc1: f16, %curr0: f32, %curr1: f16):
+    %1 = arith.addf %acc0, %curr0 : f32
+    %2 = arith.mulf %acc1, %curr1 : f16
+    tt.reduce.return %1, %2 : f32, f16
+  }) : (tensor<2x1x16xf32>, tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>)
+  // CHECK: tt.return %[[VAL0]], %[[VAL1]] : tensor<2x16xf32>, tensor<2x16xf16>
+  tt.return %0#0, %0#1 : tensor<2x16xf32>, tensor<2x16xf16>
+}
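For reference (not part of the diff), a sketch of the IR the new pattern is expected to produce for the test function above after canonicalization (e.g. via `triton-opt -canonicalize`); the SSA value names are illustrative:

```mlir
tt.func @reduce(%arg0: tensor<2x1x16xf32>, %arg1: tensor<2x1x16xf16>) -> (tensor<2x16xf32>, tensor<2x16xf16>) {
  // Both reductions run over a size-1 dimension (axis = 1), so each operand is
  // rewritten to a reshape that drops that dimension; the reduction body
  // (addf/mulf) becomes dead and is erased along with the tt.reduce op.
  %0 = tt.reshape %arg0 allow_reorder : tensor<2x1x16xf32> -> tensor<2x16xf32>
  %1 = tt.reshape %arg1 allow_reorder : tensor<2x1x16xf16> -> tensor<2x16xf16>
  tt.return %0, %1 : tensor<2x16xf32>, tensor<2x16xf16>
}
```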