@@ -65,11 +65,6 @@ namespace stablehlo {
6565
6666namespace {
6767
68- // This is an upper limit on how many elements can be folded by an op folder.
69- // This limit doesn't apply to some special cases like adding a zero,
70- // multiplying by one, doing many operations with splats.
71- constexpr int64_t kFoldOpEltLimit = 65536 ;
72-
7368// DenseElementsAttr can be constructed from ArrayRef<APInt> but not from
7469// ArrayRef<APSInt>. This helper bridges the gap.
7570DenseIntElementsAttr getTensorAttr (ShapedType type, ArrayRef<APSInt> values) {
@@ -425,15 +420,19 @@ struct EvalCompareOpPattern : public OpRewritePattern<CompareOp> {
425420
426421struct FoldConcatenateOpPattern final
427422 : OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
428- using OpRewritePattern::OpRewritePattern;
423+ FoldConcatenateOpPattern (MLIRContext* context, int64_t foldOpElementLimit,
424+ PatternBenefit benefit = 1 ,
425+ ArrayRef<StringRef> generatedNames = {})
426+ : OpRewritePattern(context, benefit, generatedNames),
427+ foldOpElementLimit (foldOpElementLimit) {}
429428
430429 LogicalResult matchAndRewrite (mlir::stablehlo::ConcatenateOp op,
431430 PatternRewriter& rewriter) const override {
432431 RankedTensorType type = op.getType ();
433432 if (!type.hasStaticShape ()) return failure ();
434433
435434 size_t numElems = type.getNumElements ();
436- if (numElems > kFoldOpEltLimit ) return failure ();
435+ if (numElems > foldOpElementLimit ) return failure ();
437436
438437 // Fold concatenate when all inputs are constants.
439438 OperandRange inputs = op.getInputs ();
@@ -463,6 +462,8 @@ struct FoldConcatenateOpPattern final
463462 op, DenseElementsAttr::get (op.getType (), newElems));
464463 return success ();
465464 }
465+
466+ int64_t foldOpElementLimit;
466467};
467468
468469struct EvalConcatenateOpPattern : public OpRewritePattern <ConcatenateOp> {
@@ -817,13 +818,18 @@ struct FoldSqrtOpPattern : public OpRewritePattern<mlir::stablehlo::SqrtOp> {
817818};
818819
819820struct EvalIotaOpPattern : public OpRewritePattern <IotaOp> {
820- using OpRewritePattern::OpRewritePattern;
821+ EvalIotaOpPattern (MLIRContext* context, int64_t foldOpElementLimit,
822+ PatternBenefit benefit,
823+ ArrayRef<StringRef> generatedNames = {})
824+ : OpRewritePattern(context, benefit, generatedNames),
825+ foldOpElementLimit (foldOpElementLimit) {}
826+
821827 LogicalResult matchAndRewrite (IotaOp op,
822828 PatternRewriter& rewriter) const override {
823829 LLVM_DEBUG (llvm::dbgs () << " EvalIotaOpPattern folding: " << op << ' \n ' );
824830 auto resultType = cast<RankedTensorType>(op.getType ());
825831 size_t numElems = resultType.getNumElements ();
826- if (numElems > kFoldOpEltLimit )
832+ if (numElems > foldOpElementLimit )
827833 return rewriter.notifyMatchFailure (op, " too many elements to fold" );
828834
829835 auto elementType = resultType.getElementType ();
@@ -864,6 +870,8 @@ struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
864870 op, DenseIntElementsAttr::get (resultType, values));
865871 return success ();
866872 }
873+
874+ int64_t foldOpElementLimit;
867875};
868876
869877template <typename RangeType>
@@ -927,7 +935,8 @@ struct StablehloAggressiveFolderPass
927935
928936 LogicalResult initialize (MLIRContext* context) override {
929937 RewritePatternSet patterns_ (context);
930- populateStablehloAggressiveFolderPatterns (&patterns_, context, foldFloat);
938+ populateStablehloAggressiveFolderPatterns (&patterns_, context, foldFloat,
939+ foldOpElementLimit);
931940 patterns = std::move (patterns_);
932941
933942 return success ();
@@ -947,16 +956,18 @@ struct StablehloAggressiveFolderPass
947956void populateStablehloAggressiveFolderPatterns (RewritePatternSet* patterns,
948957 MLIRContext* context,
949958 bool foldFloat,
959+ int64_t foldOpElementLimit,
950960 PatternBenefit benefit) {
951961 populateStablehloShapeFolderPatterns (patterns, context, foldFloat, benefit);
952- patterns->add <EvalIotaOpPattern>(context, benefit);
962+ patterns->add <EvalIotaOpPattern>(context, foldOpElementLimit, benefit);
953963 patterns->add <EvalTransposeOpPattern>(context, benefit);
954964
955965 // TODO: Consolidate FoldOp patterns
956966 // One is used by Shape Refinement, the other is a generic folder.
957967 patterns->add <FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
958- FoldConcatenateOpPattern, FoldMulOpPattern,
959- FoldSubtractOpPattern, FoldSqrtOpPattern>(context);
968+ FoldMulOpPattern, FoldSubtractOpPattern, FoldSqrtOpPattern>(
969+ context);
970+ patterns->add <FoldConcatenateOpPattern>(context, foldOpElementLimit);
960971}
961972
962973void populateStablehloShapeFolderPatterns (RewritePatternSet* patterns,
0 commit comments