diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir index 758965d06d..c523967137 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-opt --stablehlo-aggressive-folder --split-input-file --verify-diagnostics %s | FileCheck %s +// RUN: stablehlo-opt --stablehlo-aggressive-folder=fold-op-element-limit=100 --split-input-file --verify-diagnostics %s | FileCheck %s //////// // AddOp @@ -41,6 +41,21 @@ func.func @broadcast_in_dim_fold_splat(%arg0: tensor<3x3xi32>) // ----- +//////// +// ClampOp + +// CHECK-LABEL: func.func @clamp_fold +func.func @clamp_fold(%arg0: tensor<3xi32>) -> tensor<3xi32> { + %min = stablehlo.constant dense<[1, 5, 10]> : tensor<3xi32> + %max = stablehlo.constant dense<[10, 25, 12]> : tensor<3xi32> + %operand = stablehlo.constant dense<[0, 30, 11]> : tensor<3xi32> + // CHECK: stablehlo.constant dense<[1, 25, 11]> : tensor<3xi32> + %0 = "stablehlo.clamp"(%min, %operand, %max) : (tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + func.return %0: tensor<3xi32> +} + +// ----- + //////// // CompareOp @@ -102,6 +117,26 @@ func.func @concatenate_fold() -> (tensor<6xi32>, tensor<3xi32>, tensor<3x3xi32>, // ----- +//////// +// DivOp + +// CHECK-LABEL: @div_fold_cst +func.func @div_fold_cst() -> (tensor, tensor, tensor) { + %cst = stablehlo.constant dense<2> : tensor + %cst_1 = stablehlo.constant dense<2> : tensor + %cst_2 = stablehlo.constant dense<2.0> : tensor + // CHECK: stablehlo.constant dense<1> : tensor + // CHECK: stablehlo.constant dense<1> : tensor + // CHECK: stablehlo.divide{{.*}} : tensor + // DISABLED-CHECK: stablehlo.constant dense<1.0{{.*}}> : tensor + %0 = stablehlo.divide %cst, %cst : tensor + %1 = stablehlo.divide %cst_1, %cst_1 : tensor + %2 = stablehlo.divide %cst_2, %cst_2 : tensor + return %0, %1, %2 : tensor, tensor, tensor +} + +// ----- + //////// // MulOp diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir index 4921a224f4..a1eed9445a 100644 --- a/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir +++ b/stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir @@ -1,4 +1,4 @@ -// RUN: stablehlo-opt --stablehlo-aggressive-simplification --allow-unregistered-dialect --split-input-file %s | FileCheck %s +// RUN: stablehlo-opt --stablehlo-aggressive-simplification=fold-op-element-limit=100 --allow-unregistered-dialect --split-input-file %s | FileCheck %s ///////// // AddOp diff --git a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir index b262bf095e..01b5dd8feb 100644 --- a/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir +++ b/stablehlo/tests/transforms/stablehlo_refine_shapes.mlir @@ -521,16 +521,16 @@ func.func @eval_slice_zerodim() -> tensor<0x2x1xi64> { // ----- // CHECK-LABEL: func @eval_slice_zerorank -func.func @eval_slice_zerorank() -> tensor { - // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<3.300000e+01> : tensor +func.func @eval_slice_zerorank() -> tensor { + // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<33> : tensor // CHECK: return [[RESULT]] - %0 = stablehlo.constant dense<33.0> : tensor + %0 = stablehlo.constant dense<33> : tensor %1 = "stablehlo.slice"(%0) { start_indices = array, limit_indices = array, strides = array - } : (tensor) -> tensor - func.return %1 : tensor + } : (tensor) -> tensor + func.return %1 : tensor } // ----- diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index 64eaa3c01b..4b585b46da 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include #include #include @@ -1038,8 +1039,11 @@ LogicalResult applyShapeRefinementPatterns(func::FuncOp func, // Populate additional patterns for StableHLO extensions. state.addAdditionalPatterns(patterns); + // No float folding and fold as much as possible. Shape refinement will fail + // if int shape computations are unable to be folded. StablehloAggressiveFolderPassOptions folderOptions; folderOptions.optimizeFloat = false; + folderOptions.foldOpElementLimit = std::numeric_limits::max(); // The folding patterns implement partial evaluation of shape computations // which is a critical part of implementing type refinement for ops like diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp index 9cfac0ba5b..5ba200ffeb 100644 --- a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp +++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp @@ -19,6 +19,7 @@ limitations under the License. #include #include #include +#include #include #include "llvm/ADT/APInt.h" @@ -47,6 +48,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" @@ -68,13 +70,6 @@ namespace { static constexpr StablehloAggressiveFolderPassOptions kDefaultOptions; -// DenseElementsAttr can be constructed from ArrayRef but not from -// ArrayRef. This helper bridges the gap. -DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { - SmallVector supportedValues(values); - return DenseIntElementsAttr::get(type, supportedValues); -} - APSInt getAPSInt(Type type, uint64_t value) { unsigned numBits; bool isUnsigned; @@ -98,48 +93,106 @@ LogicalResult validateStaticShapeResult(PatternRewriter& rewriter, return success(); } +template +static TypedAttr foldUnaryOpIntOrFloat(Type resultType, TypedAttr operand, + Fn&& folder) { + Type elemTy = getElementTypeOrSelf(operand); + + Attribute res; + if (isa(elemTy)) + res = constFoldUnaryOp( + operand, std::forward(folder)); + else if (isa(elemTy)) + res = constFoldUnaryOp( + operand, std::forward(folder)); + if (!res) return nullptr; + return cast(res); +} + +/// Binary constant folder that used a generic folder function to handle both +/// ints and floats. +template +FailureOr foldUnaryOpIntOrFloat(PatternRewriter& rewriter, + Operation* op, Fn&& folder) { + if (op->getNumOperands() != 1 || op->getNumResults() != 1) + return rewriter.notifyMatchFailure(op, "expected unary op"); + + TypedAttr attr; + matchPattern(op->getOperand(0), m_Constant(&attr)); + + if (!attr) return rewriter.notifyMatchFailure(op, "operand not constants"); + + TypedAttr res = foldUnaryOpIntOrFloat(op->getResultTypes()[0], attr, + std::forward(folder)); + if (!res) return rewriter.notifyMatchFailure(op, "folding failed"); + + return res; +} + /// Binary constant folder that used a generic folder function to handle both /// ints and floats. template -static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, - Fn&& folder) { +static TypedAttr foldBinaryOpIntOrFloat(Type resultType, TypedAttr lhs, + TypedAttr rhs, Fn&& folder) { Attribute operands[2] = {lhs, rhs}; Type elemTy = getElementTypeOrSelf(lhs); Attribute res; if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (res) return cast(res); - - return nullptr; + res = constFoldBinaryOp( + operands, resultType, std::forward(folder)); + else if (isa(elemTy)) + res = constFoldBinaryOp( + operands, resultType, std::forward(folder)); + if (!res) return nullptr; + return cast(res); +} + +/// Binary constant folder that used a generic folder function to handle both +/// ints and floats. +template +FailureOr foldBinaryOpIntOrFloat(PatternRewriter& rewriter, + Operation* op, Fn&& folder) { + if (op->getNumOperands() != 2 || op->getNumResults() != 1) + return rewriter.notifyMatchFailure(op, "expected binary op"); + + TypedAttr lhsAttr, rhsAttr; + matchPattern(op->getOperand(0), m_Constant(&lhsAttr)); + matchPattern(op->getOperand(1), m_Constant(&rhsAttr)); + + if (!lhsAttr || !rhsAttr) + return rewriter.notifyMatchFailure(op, "lhs & rhs operands not constants"); + + TypedAttr res = foldBinaryOpIntOrFloat(op->getResultTypes()[0], lhsAttr, + rhsAttr, std::forward(folder)); + if (!res) return rewriter.notifyMatchFailure(op, "folding failed"); + + return res; } template -LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, +LogicalResult foldConvertHelper(PatternRewriter& rewriter, OpType op, DenseIntOrFPElementsAttr elements, Type resType, CalculationT&& calculate) { auto result = constFoldCastOp( - elements, resType, calculate); + elements, resType, std::forward(calculate)); - if (!result) + if (!result) { return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { diag << "cast of " << elements.getElementType() << " to " << resType << " failed"; }); + } rewriter.replaceOpWithNewOp(op, result); return success(); } template -LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, +LogicalResult foldConvert(PatternRewriter& rewriter, OpType op, DenseIntOrFPElementsAttr elements, RankedTensorType resultType) { auto oldType = getElementTypeOrSelf(elements); @@ -153,7 +206,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, if (auto newFloatType = dyn_cast(newType)) { // Float -> Float const auto& targetSemantics = newFloatType.getFloatSemantics(); - return evalConvertHelper( + return foldConvertHelper( rewriter, op, elements, resultType, [&targetSemantics](const APFloat& operand, bool& castStatus) { bool losesInfo; @@ -167,7 +220,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, } // Float -> Int - return evalConvertHelper( + return foldConvertHelper( rewriter, op, elements, resultType, [&newBitWidth, &isNewTypeUnsigned](const APFloat& operand, bool& castStatus) { @@ -186,7 +239,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, if (auto newFloatType = dyn_cast(newType)) { // Int -> Float - return evalConvertHelper( + return foldConvertHelper( rewriter, op, elements, resultType, [&newFloatType, &isOldTypeUnsigned](const APInt& operand, bool& /*castStatus*/) { @@ -199,7 +252,7 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, } // Int -> Int - return evalConvertHelper( + return foldConvertHelper( rewriter, op, elements, resultType, [&newBitWidth, &isOldTypeUnsigned](const APInt& operand, bool& /*castStatus*/) { @@ -207,58 +260,6 @@ LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, }); } -// The patterns below implement partial evaluation of shape computations which -// is a critical part of implementing type refinement for ops like -// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape -// depends on the value of their shape operands. - -template -LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, - FuncType fn) { - auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) - return failure(); - - if (!isa(resultType.getElementType())) - return rewriter.notifyMatchFailure(op, - "expected integer result tensor type"); - - SmallVector result; - if constexpr (OpType::template hasTrait()) { - SmallVector operand; - if (failed(hlo::matchInts(op.getOperand(), operand))) - return rewriter.notifyMatchFailure(op, "expected constant operand"); - for (const auto& operandEl : operand) { - result.push_back(fn(operandEl)); - } - } else if constexpr (OpType::template hasTrait< - OpTrait::NOperands<2>::Impl>()) { - SmallVector lhs, rhs; - if (failed(hlo::matchInts(op.getLhs(), lhs)) || - failed(hlo::matchInts(op.getRhs(), rhs))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { - result.push_back(fn(lhsEl, rhsEl)); - } - } else if constexpr (OpType::template hasTrait< - OpTrait::NOperands<3>::Impl>()) { - SmallVector x, y, z; - if (failed(hlo::matchInts(op->getOperand(0), x)) || - failed(hlo::matchInts(op->getOperand(1), y)) || - failed(hlo::matchInts(op->getOperand(2), z))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { - result.push_back(fn(xEl, yEl, zEl)); - } - } else { - llvm::report_fatal_error("unsupported number of operands"); - } - - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - return success(); -} - template struct FoldOpRewritePattern : OpRewritePattern { FoldOpRewritePattern(MLIRContext* context, @@ -275,29 +276,18 @@ struct FoldOpRewritePattern : OpRewritePattern { ArrayRef generatedNames = {}) = delete; const StablehloAggressiveFolderPassOptions& options; -}; -struct FoldAddOpPattern final : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, - PatternRewriter& rewriter) const override { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - // Pattern: add(cst,cst) -> cst - TypedAttr lhsAttr, rhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - matchPattern(rhs, m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); + LogicalResult validateElementCountForFold(PatternRewriter& rewriter, + Operation* op, + ShapedType resultType) const { + size_t numElems = resultType.getNumElements(); + if (numElems > static_cast(options.foldOpElementLimit)) + return rewriter.notifyMatchFailure( + op, + "too many elements, fold " + "limit is " + + std::to_string(options.foldOpElementLimit)); + return success(); } }; @@ -318,100 +308,102 @@ struct ShapeOpRewritePattern : public FoldOpRewritePattern { } }; -struct EvalAddOpShapePattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldAddOpPattern final + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; - LogicalResult matchAndRewrite(AddOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); + + auto res = foldBinaryOpIntOrFloat(rewriter, op, std::plus<>{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } }; -struct EvalAndOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldAndOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; - LogicalResult matchAndRewrite(AndOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::AndOp op, PatternRewriter& rewriter) const override { + // TODO: Support more int types auto resultType = op.getType(); if (!resultType.getElementType().isInteger(1)) return rewriter.notifyMatchFailure(op, "expected boolean element type"); - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); - }); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldAnd{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } -}; - -// Pattern: broadcast_in_dim(splat, _) -> constant(splat) -struct FoldBroadcastInDimSplatPattern final - : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, - PatternRewriter& rewriter) const override { - TypedValue operand = op.getOperand(); - if (SplatElementsAttr cstAttr; - matchPattern(operand, m_Constant(&cstAttr))) { - rewriter.replaceOpWithNewOp( - op, SplatElementsAttr::get(op.getType(), - cstAttr.getSplatValue())); - return success(); + struct FoldAnd { + APInt operator()(APInt lhs, APInt rhs) const { + return APInt(lhs.getBitWidth(), !lhs.isZero() && !rhs.isZero()); } - return failure(); - } + std::optional operator()(APFloat lhs, APFloat rhs) const { + return std::nullopt; + } + }; }; -struct EvalBroadcastInDimOpPattern - : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +// Pattern: broadcast_in_dim(splat, _) -> constant(splat) +struct FoldBroadcastInDimOpSplatPattern + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(BroadcastInDimOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) + if (failed(validateStaticShapeResult(rewriter, op, resultType)) || + failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); - auto operandType = op.getOperand().getType(); - if (operandType.getRank() != 0) - return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); - - SmallVector operand; - if (failed(hlo::matchInts(op.getOperand(), operand))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - auto scalar = operand[0]; + SplatElementsAttr cstAttr; + matchPattern(op.getOperand(), m_Constant(&cstAttr)); + if (!cstAttr) return rewriter.notifyMatchFailure(op, "operand not splat"); - rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), scalar)); + rewriter.replaceOpWithNewOp( + op, SplatElementsAttr::get(op.getType(), + cstAttr.getSplatValue())); return success(); } }; -struct EvalClampOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldCompareOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; - LogicalResult matchAndRewrite(ClampOp op, + LogicalResult matchAndRewrite(CompareOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt min, APSInt operand, APSInt max) { - if (operand < min) return min; - if (max < operand) return max; - return operand; - }); + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + + auto res = foldBinaryOpIntOrFloat( + rewriter, op, + FoldCompare(op.getComparisonDirection(), op.getCompareType())); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } -}; -struct EvalCompareOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; + struct FoldCompare { + FoldCompare(ComparisonDirection direction, + std::optional kind) + : direction(direction), kind(kind) {} + ComparisonDirection direction; + std::optional kind; - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - auto kind = op.getCompareType(); - return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; + } + APInt operator()(APInt lhs, APInt rhs) { bool result = false; - switch (op.getComparisonDirection()) { + switch (direction) { case ComparisonDirection::EQ: result = lhs == rhs; break; @@ -431,9 +423,9 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs); break; } - return getAPSInt(resultType.getElementType(), result); - }); - } + return APInt(/*bitwidth=*/1, result); + } + }; }; ////////////////////////////////// @@ -441,16 +433,15 @@ struct EvalCompareOpPattern : public FoldOpRewritePattern { ///////////////////////////////// struct FoldConcatenateOpPattern final - : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; + : ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, PatternRewriter& rewriter) const override { RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) return failure(); - - size_t numElems = type.getNumElements(); - if (numElems > static_cast(options.foldOpElementLimit)) + if (failed(validateStaticShapeResult(rewriter, op, type)) || + failed(validateShapeFoldDtype(rewriter, op, type)) || + failed(validateElementCountForFold(rewriter, op, type))) return failure(); // Fold concatenate when all inputs are constants. @@ -466,6 +457,7 @@ struct FoldConcatenateOpPattern final int64_t{1}, std::multiplies<>{}); SmallVector newElems; + size_t numElems = type.getNumElements(); newElems.reserve(numElems); for (int64_t i = 0; i != topSize; ++i) { @@ -485,31 +477,7 @@ struct FoldConcatenateOpPattern final int64_t foldOpElementLimit; }; -struct EvalConcatenateOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; - - LogicalResult matchAndRewrite(ConcatenateOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) - return failure(); - - if (op.getDimension() != 0) - return rewriter.notifyMatchFailure(op, "expected dimension = 0"); - - SmallVector result; - for (Value operand : op->getOperands()) { - if (failed(hlo::matchInts(operand, result))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - } - - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - return success(); - } -}; - -struct EvalConvertOpPattern : public ShapeOpRewritePattern { +struct FoldConvertOpPattern : public ShapeOpRewritePattern { using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(ConvertOp op, @@ -532,28 +500,50 @@ struct EvalConvertOpPattern : public ShapeOpRewritePattern { return rewriter.notifyMatchFailure( op, "expected constant integer or float operand"); - return evalConvert(rewriter, op, elements, resultType); + return foldConvert(rewriter, op, elements, resultType); } }; -struct EvalDivOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldDivOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(DivOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + + bool isUnsignedInt = resultType.getElementType().isUnsignedInteger(); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldDivide(isUnsignedInt)); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } + + struct FoldDivide { + FoldDivide(bool isUnsignedInt) + : foldIntFn(isUnsignedInt ? foldUint : foldSint) {} + std::function foldIntFn; + + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; // return lhs / rhs; + } + APInt operator()(APInt lhs, APInt rhs) { return foldIntFn(lhs, rhs); } + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.udiv(rhs); } + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.sdiv(rhs); } + }; }; -struct EvalGetDimensionSizeOpPattern - : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldGetDimensionSizeOpPattern + : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(GetDimensionSizeOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) + if (failed(validateStaticShapeResult(rewriter, op, resultType)) || + failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); auto operandType = op.getOperand().getType(); @@ -567,86 +557,186 @@ struct EvalGetDimensionSizeOpPattern } }; -struct EvalMaxOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +///// +// Max/Min/Clamp +///// + +struct FoldMax { + FoldMax(bool isUnsignedInt) + : foldIntFn(isUnsignedInt ? foldUint : foldSint) {} + std::function foldIntFn; + + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; // return lhs >= rhs ? lhs : rhs; + } + APInt operator()(APInt lhs, APInt rhs) { return foldIntFn(lhs, rhs); } + static APInt foldUint(APInt lhs, APInt rhs) { + return lhs.uge(rhs) ? lhs : rhs; + } + static APInt foldSint(APInt lhs, APInt rhs) { + return lhs.sge(rhs) ? lhs : rhs; + } +}; + +struct FoldMin { + FoldMin(bool isUnsignedInt) + : foldIntFn(isUnsignedInt ? foldUint : foldSint) {} + std::function foldIntFn; + + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; // return lhs <= rhs ? lhs : rhs; + } + APInt operator()(APInt lhs, APInt rhs) { return foldIntFn(lhs, rhs); } + static APInt foldUint(APInt lhs, APInt rhs) { + return lhs.ule(rhs) ? lhs : rhs; + } + static APInt foldSint(APInt lhs, APInt rhs) { + return lhs.sle(rhs) ? lhs : rhs; + } +}; + +struct FoldMaxOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(MaxOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs >= rhs ? lhs : rhs; - }); + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + + bool isUnsignedInt = resultType.getElementType().isUnsignedInteger(); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldMax(isUnsignedInt)); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } }; -struct EvalMinOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldMinOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(MinOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs <= rhs ? lhs : rhs; - }); + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + + bool isUnsignedInt = resultType.getElementType().isUnsignedInteger(); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldMin(isUnsignedInt)); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } }; -struct FoldMulOpPattern final : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +// Clamp is folded using Min and Max folders. +struct FoldClampOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; - LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + LogicalResult matchAndRewrite(ClampOp op, PatternRewriter& rewriter) const override { - TypedAttr lhsAttr; - matchPattern(op.getLhs(), m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; - matchPattern(op.getRhs(), m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); - return failure(); + TypedAttr minAttr, operandAttr, maxAttr; + matchPattern(op.getMin(), m_Constant(&minAttr)); + matchPattern(op.getOperand(), m_Constant(&operandAttr)); + matchPattern(op.getMax(), m_Constant(&maxAttr)); + + if (!minAttr || !operandAttr || !maxAttr) + return rewriter.notifyMatchFailure(op, "operands not constant"); + + // Fold clamp using: + // res = max(min, operand) + // res = min(max, res) + bool isUnsignedInt = resultType.getElementType().isUnsignedInteger(); + auto res = foldBinaryOpIntOrFloat(resultType, minAttr, operandAttr, + FoldMax(isUnsignedInt)); + res = foldBinaryOpIntOrFloat(resultType, maxAttr, res, + FoldMin(isUnsignedInt)); + if (!res) return rewriter.notifyMatchFailure(op, "failed to fold clamp"); + rewriter.replaceOpWithNewOp(op, res); + return success(); } }; -struct EvalMulOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldMulOpPattern final : ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; - LogicalResult matchAndRewrite(MulOp op, + LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); + + auto res = foldBinaryOpIntOrFloat(rewriter, op, std::multiplies<>{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } }; -struct EvalOrOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldOrOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(OrOp op, PatternRewriter& rewriter) const override { + // TODO: Support more int types auto resultType = op.getType(); if (!resultType.getElementType().isInteger(1)) return rewriter.notifyMatchFailure(op, "expected boolean element type"); - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); - }); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldOr{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } + + struct FoldOr { + APInt operator()(APInt lhs, APInt rhs) const { + return APInt(lhs.getBitWidth(), !lhs.isZero() || !rhs.isZero()); + } + std::optional operator()(APFloat lhs, APFloat rhs) const { + return std::nullopt; + } + }; }; -struct EvalRemOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldRemOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(RemOp op, PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); + auto resultType = op.getType(); + if (failed(validateShapeFoldDtype(rewriter, op, resultType))) + return failure(); + + bool isUnsignedInt = resultType.getElementType().isUnsignedInteger(); + auto res = foldBinaryOpIntOrFloat(rewriter, op, FoldRem(isUnsignedInt)); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } + + struct FoldRem { + FoldRem(bool isUnsignedInt) + : foldIntFn(isUnsignedInt ? foldUint : foldSint) {} + std::function foldIntFn; + + // TODO: Enable float folding. + std::optional operator()(APFloat lhs, APFloat rhs) { + return std::nullopt; // return lhs.remainder(rhs); + } + APInt operator()(APInt lhs, APInt rhs) { return foldIntFn(lhs, rhs); } + static APInt foldUint(APInt lhs, APInt rhs) { return lhs.urem(rhs); } + static APInt foldSint(APInt lhs, APInt rhs) { return lhs.srem(rhs); } + }; }; -struct EvalReshapeOpPattern : public ShapeOpRewritePattern { +// Pattern: reshape(cst, shape) -> cst +struct FoldReshapeOpPattern : public ShapeOpRewritePattern { using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(ReshapeOp op, @@ -656,7 +746,6 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); - // Pattern: reshape(cst, shape) -> cst DenseIntOrFPElementsAttr attr; if (!matchPattern(op.getOperand(), m_Constant(&attr))) return rewriter.notifyMatchFailure(op, "expected constant operand"); @@ -665,53 +754,98 @@ struct EvalReshapeOpPattern : public ShapeOpRewritePattern { } }; -struct EvalSelectOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldSelectOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(SelectOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) + if (failed(validateStaticShapeResult(rewriter, op, resultType)) || + failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); - SmallVector pred, onTrue, onFalse; - if (failed(hlo::matchInts(op.getPred(), pred)) || - failed(hlo::matchInts(op.getOnTrue(), onTrue)) || - failed(hlo::matchInts(op.getOnFalse(), onFalse))) + DenseIntElementsAttr predAttr; + DenseElementsAttr onTrueAttr, onFalseAttr; + matchPattern(op.getPred(), m_Constant(&predAttr)); + matchPattern(op.getOnTrue(), m_Constant(&onTrueAttr)); + matchPattern(op.getOnFalse(), m_Constant(&onFalseAttr)); + if (!predAttr || !onTrueAttr || !onFalseAttr) return rewriter.notifyMatchFailure(op, "expected constant operands"); - SmallVector result; - for (auto [predEl, onTrueEl, onFalseEl] : - llvm::zip(pred, onTrue, onFalse)) { - result.push_back(predEl != 0 ? onTrueEl : onFalseEl); + // Optimization, handle splat predicate + if (isa(predAttr)) { + auto pred = predAttr.getSplatValue(); + rewriter.replaceOpWithNewOp( + op, pred.isZero() ? onFalseAttr : onTrueAttr); + return success(); } + // TODO: Enable float folding. + if (op.getType().getElementType().isFloat()) + return rewriter.notifyMatchFailure(op, "float select not supported yet"); + + // Fall back to verbose folding + if (failed(validateElementCountForFold(rewriter, op, resultType))) + return failure(); + + SmallVector result; + for (auto [predEl, onTrueEl, onFalseEl] : + llvm::zip(predAttr.getValues(), onTrueAttr.getValues(), + onFalseAttr.getValues())) { + result.push_back(!predEl.isZero() ? onTrueEl : onFalseEl); + } rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), result)); + op, DenseIntElementsAttr::get(resultType, result)); + return success(); } + + struct FoldSelect { + std::optional operator()(APFloat pred, APFloat onTrue, + APFloat onFalse) { + return std::nullopt; + } + + APInt operator()(APInt pred, APInt onTrue, APInt onFalse) { + return pred != 0 ? onTrue : onFalse; + } + }; }; -struct EvalSignOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldSignOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(SignOp op, PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (!isa(resultType.getElementType())) - return rewriter.notifyMatchFailure(op, - "expected integer result tensor type"); - return evalElementwise(rewriter, op, [&](APSInt operand) { + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); + + auto elementType = op.getType().getElementType(); + auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSign(elementType)); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); + } + + struct FoldSign { + FoldSign(Type elementType) : elementType(elementType) {} + Type elementType; + // TODO: Enable float folding. + std::optional operator()(APFloat operand) { return std::nullopt; } + + APInt operator()(APInt operand) { + // SignOp only supports signed integers. + APSInt signedInt = getAPSInt(elementType, operand.getSExtValue()); int64_t result; - if (operand.isNegative()) + if (signedInt.isNegative()) result = -1; - else if (operand.isZero()) + else if (signedInt.isZero()) result = 0; else result = 1; - return getAPSInt(resultType.getElementType(), result); - }); - } + return getAPSInt(elementType, result); + } + }; }; template @@ -749,13 +883,14 @@ DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) { ArrayRef(result)); } -struct EvalSliceOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; +struct FoldSliceOpPattern : public ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(SliceOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (failed(validateStaticShapeResult(rewriter, op, resultType))) + if (failed(validateStaticShapeResult(rewriter, op, resultType)) || + failed(validateShapeFoldDtype(rewriter, op, resultType))) return failure(); auto operand = op.getOperand(); @@ -784,36 +919,18 @@ struct EvalSliceOpPattern : public FoldOpRewritePattern { }; struct FoldSubtractOpPattern final - : FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; + : ShapeOpRewritePattern { + using ShapeOpRewritePattern::ShapeOpRewritePattern; LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, PatternRewriter& rewriter) const override { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - TypedAttr lhsAttr, rhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - matchPattern(rhs, m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); - } -}; - -struct EvalSubtractOpPattern : public FoldOpRewritePattern { - using FoldOpRewritePattern::FoldOpRewritePattern; + if (failed(validateShapeFoldDtype(rewriter, op, op.getType()))) + return failure(); - LogicalResult matchAndRewrite(SubtractOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); + auto res = foldBinaryOpIntOrFloat(rewriter, op, std::minus<>{}); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); } }; @@ -823,42 +940,36 @@ struct FoldSqrtOpPattern LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op, PatternRewriter& rewriter) const final { - TypedAttr lhsAttr; - matchPattern(op.getOperand(), m_Constant(&lhsAttr)); + auto res = foldUnaryOpIntOrFloat(rewriter, op, FoldSqrt()); + if (failed(res)) return failure(); + rewriter.replaceOpWithNewOp(op, res.value()); + return success(); + } - if (!lhsAttr) - return rewriter.notifyMatchFailure(op, "operand not constant"); + struct FoldSqrt { + std::optional operator()(APFloat operand) { + if (operand.getSizeInBits(operand.getSemantics()) == 64) + return APFloat(std::sqrt(operand.convertToDouble())); - if (auto res = constFoldUnaryOp( - lhsAttr, foldSqrt)) { - rewriter.replaceOpWithNewOp( - op, op.getType(), llvm::cast(res)); - return success(); + if (operand.getSizeInBits(operand.getSemantics()) == 32) + return APFloat(sqrtf(operand.convertToFloat())); + return std::nullopt; } - return rewriter.notifyMatchFailure(op, "unable to fold sqrt"); - } - - static std::optional foldSqrt(const APFloat& a) { - if (a.getSizeInBits(a.getSemantics()) == 64) - return APFloat(std::sqrt(a.convertToDouble())); - - if (a.getSizeInBits(a.getSemantics()) == 32) - return APFloat(sqrtf(a.convertToFloat())); - return {}; - } + // TODO: Enable int folding. + std::optional operator()(APInt operand) { return std::nullopt; } + }; }; -struct EvalIotaOpPattern : public FoldOpRewritePattern { +struct FoldIotaOpPattern : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; LogicalResult matchAndRewrite(IotaOp op, PatternRewriter& rewriter) const override { - LLVM_DEBUG(llvm::dbgs() << "EvalIotaOpPattern folding: " << op << '\n'); + LLVM_DEBUG(llvm::dbgs() << "FoldIotaOpPattern folding: " << op << '\n'); auto resultType = cast(op.getType()); - size_t numElems = resultType.getNumElements(); - if (numElems > static_cast(options.foldOpElementLimit)) - return rewriter.notifyMatchFailure(op, "too many elements to fold"); + if (failed(validateElementCountForFold(rewriter, op, resultType))) + return failure(); auto elementType = resultType.getElementType(); @@ -929,7 +1040,7 @@ DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) { // transpose(constant) => constant with permuted dimensions // This covers ranked tensor types with 0 dimensions(zero elements) and 0 // rank(scalar), as well as splat values. -struct EvalTransposeOpPattern : public FoldOpRewritePattern { +struct FoldTransposeOpPattern : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; LogicalResult matchAndRewrite(TransposeOp op, @@ -943,6 +1054,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { return rewriter.notifyMatchFailure( op, "expected constant integer or float operand"); + // TODO: Does this expand splat values? Should we special case splats? DenseElementsAttr resAttr; if (auto data = els.tryGetValues()) resAttr = transposeType(op, *data); @@ -957,6 +1069,7 @@ struct EvalTransposeOpPattern : public FoldOpRewritePattern { } }; +// TODO: Consider moving this into aggressive simplifications. struct LowerBoolSplatConstantsIntoReduceOpRegion : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; @@ -1160,8 +1273,7 @@ bool hasNoDeclaredSideEffects(Operation* op) { return true; } -struct RemoveDeadWhileOpWithNoSideEffects - : public FoldOpRewritePattern { +struct FoldWhileOpDeadWithNoSideEffects : public FoldOpRewritePattern { using FoldOpRewritePattern::FoldOpRewritePattern; LogicalResult matchAndRewrite(WhileOp op, @@ -1233,23 +1345,16 @@ void populateStablehloAggressiveFolderPatterns( PatternBenefit benefit) { populateStablehloShapeFolderPatterns(context, patterns, options, benefit); - patterns->add(context, options, benefit); - - // TODO: Consolidate FoldOp patterns - // One is used by Shape Refinement, the other is a generic folder. - patterns->add(context, options); + patterns->add(context, options, + benefit); } class StablehloTargetIndependentOptimizationPass { @@ -1266,25 +1371,25 @@ void populateStablehloShapeFolderPatterns( MLIRContext* context, RewritePatternSet* patterns, const StablehloAggressiveFolderPassOptions& options, PatternBenefit benefit) { - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); - patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); + patterns->add(context, options, benefit); } void populateStablehloShapeFolderPatterns(MLIRContext* context, diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td index 8adb6dbc95..3a8edc61e3 100644 --- a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td @@ -49,6 +49,8 @@ def RankEqual : Constraint< CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, "same rank">; +def TensorDimsAllOne : Constraint, "all tensor dims are 1">; + def TypesEqual : Constraint, "operands are equal">; /////////// @@ -101,6 +103,8 @@ def ZeroExtent : AttrConstraint< CPred<"cast($_self).getNumElements() == 0">, "is zero extent">; +def AnyStaticShapeIntTensor : StaticShapeTensorOf<[HLO_Int]>; + /////////// //// Native Code Call Utilities @@ -503,7 +507,7 @@ def SelectOp_InvertBroadcastPredicateAndSwap // Must be static shape, otherwise would require broadcasting via // CHLO_ConstantLike. def SubtractOp_FoldToZero - : Pat<(StableHLO_SubtractOp AnyStaticShapeTensor:$operand, $operand), + : Pat<(StableHLO_SubtractOp AnyStaticShapeIntTensor:$operand, $operand), (StableHLO_ConstantLike<"0"> $operand)>; // Pattern: subtract(X, 0) -> X