Skip to content

Commit c672ff6

Browse files
committed
Set op-folding element limit via a config option
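Instead of hard-coding the maximum number of elements that an op folder may fold, set the limit via a pass option, `fold-op-element-limit`. (The default value matches the old constant's value of 65536.) The option list is now shared between the aggressive folder pass and the target-independent optimization pass, so the latter also exposes `fold-float` instead of hard-coding it to true; note that the target-independent pass additionally uses the same limit as the greedy rewriter's `maxIterations`.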
1 parent 8d9a84b commit c672ff6

File tree

4 files changed

+40
-24
lines changed

4 files changed

+40
-24
lines changed

stablehlo/transforms/optimization/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#ifndef STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H
1717
#define STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H
1818

19+
#include <cstdint>
1920
#include <memory>
2021

2122
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -42,6 +43,7 @@ void populateStablehloCanonicalizationPatterns(MLIRContext *context,
4243
void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns,
4344
MLIRContext *context,
4445
bool foldFloat,
46+
int64_t foldOpElementLimit,
4547
PatternBenefit benefit = 1);
4648

4749
/// A subset of folding patterns for StableHLO that is necessary for shape

stablehlo/transforms/optimization/Passes.td

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,23 @@ limitations under the License.
1515

1616
include "mlir/Pass/PassBase.td"
1717

18+
defvar stablehlo_aggressive_folder_options = [
19+
Option<"foldFloat", "fold-float", "bool", /*default=*/"true",
20+
"Allow for potentially lossy computations using float type.">,
21+
Option<"foldOpElementLimit", "fold-op-element-limit", "int64_t",
22+
/*default=*/"65536",
23+
"Upper limit on how many elements may be folded by an op folder. "
24+
"This limit doesn't apply in certain special cases such as when "
25+
"adding 0, multiplying by 1, or operating on splats in some ways.">,
26+
];
27+
1828
def StablehloAggressiveFolderPass
1929
: Pass<"stablehlo-aggressive-folder", "func::FuncOp"> {
2030
let summary = "Folds StableHLO operations";
2131
let dependentDialects = [
2232
"mlir::stablehlo::StablehloDialect",
2333
];
24-
let options = [
25-
Option<"foldFloat", "fold-float", "bool", /*default=*/"true",
26-
"Allow for potentially lossy computations using float type.">,
27-
];
34+
let options = stablehlo_aggressive_folder_options;
2835
}
2936

3037
def StablehloAggressiveSimplificationPass
@@ -138,4 +145,5 @@ def StablehloTargetIndependentOptimizationPass
138145
let dependentDialects = [
139146
"mlir::stablehlo::StablehloDialect",
140147
];
148+
let options = stablehlo_aggressive_folder_options;
141149
}

stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ namespace stablehlo {
6565

6666
namespace {
6767

68-
// This is an upper limit on how many elements can be folded by an op folder.
69-
// This limit doesn't apply to some special cases like adding a zero,
70-
// multiplying by one, doing many operations with splats.
71-
constexpr int64_t kFoldOpEltLimit = 65536;
72-
7368
// DenseElementsAttr can be constructed from ArrayRef<APInt> but not from
7469
// ArrayRef<APSInt>. This helper bridges the gap.
7570
DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef<APSInt> values) {
@@ -425,15 +420,19 @@ struct EvalCompareOpPattern : public OpRewritePattern<CompareOp> {
425420

426421
struct FoldConcatenateOpPattern final
427422
: OpRewritePattern<mlir::stablehlo::ConcatenateOp> {
428-
using OpRewritePattern::OpRewritePattern;
423+
FoldConcatenateOpPattern(MLIRContext* context, int64_t foldOpElementLimit,
424+
PatternBenefit benefit = 1,
425+
ArrayRef<StringRef> generatedNames = {})
426+
: OpRewritePattern(context, benefit, generatedNames),
427+
foldOpElementLimit(foldOpElementLimit) {}
429428

430429
LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op,
431430
PatternRewriter& rewriter) const override {
432431
RankedTensorType type = op.getType();
433432
if (!type.hasStaticShape()) return failure();
434433

435434
size_t numElems = type.getNumElements();
436-
if (numElems > kFoldOpEltLimit) return failure();
435+
if (numElems > foldOpElementLimit) return failure();
437436

438437
// Fold concatenate when all inputs are constants.
439438
OperandRange inputs = op.getInputs();
@@ -463,6 +462,8 @@ struct FoldConcatenateOpPattern final
463462
op, DenseElementsAttr::get(op.getType(), newElems));
464463
return success();
465464
}
465+
466+
int64_t foldOpElementLimit;
466467
};
467468

468469
struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
@@ -817,13 +818,18 @@ struct FoldSqrtOpPattern : public OpRewritePattern<mlir::stablehlo::SqrtOp> {
817818
};
818819

819820
struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
820-
using OpRewritePattern::OpRewritePattern;
821+
EvalIotaOpPattern(MLIRContext* context, int64_t foldOpElementLimit,
822+
PatternBenefit benefit,
823+
ArrayRef<StringRef> generatedNames = {})
824+
: OpRewritePattern(context, benefit, generatedNames),
825+
foldOpElementLimit(foldOpElementLimit) {}
826+
821827
LogicalResult matchAndRewrite(IotaOp op,
822828
PatternRewriter& rewriter) const override {
823829
LLVM_DEBUG(llvm::dbgs() << "EvalIotaOpPattern folding: " << op << '\n');
824830
auto resultType = cast<RankedTensorType>(op.getType());
825831
size_t numElems = resultType.getNumElements();
826-
if (numElems > kFoldOpEltLimit)
832+
if (numElems > foldOpElementLimit)
827833
return rewriter.notifyMatchFailure(op, "too many elements to fold");
828834

829835
auto elementType = resultType.getElementType();
@@ -864,6 +870,8 @@ struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
864870
op, DenseIntElementsAttr::get(resultType, values));
865871
return success();
866872
}
873+
874+
int64_t foldOpElementLimit;
867875
};
868876

869877
template <typename RangeType>
@@ -927,7 +935,8 @@ struct StablehloAggressiveFolderPass
927935

928936
LogicalResult initialize(MLIRContext* context) override {
929937
RewritePatternSet patterns_(context);
930-
populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat);
938+
populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat,
939+
foldOpElementLimit);
931940
patterns = std::move(patterns_);
932941

933942
return success();
@@ -947,16 +956,18 @@ struct StablehloAggressiveFolderPass
947956
void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
948957
MLIRContext* context,
949958
bool foldFloat,
959+
int64_t foldOpElementLimit,
950960
PatternBenefit benefit) {
951961
populateStablehloShapeFolderPatterns(patterns, context, foldFloat, benefit);
952-
patterns->add<EvalIotaOpPattern>(context, benefit);
962+
patterns->add<EvalIotaOpPattern>(context, foldOpElementLimit, benefit);
953963
patterns->add<EvalTransposeOpPattern>(context, benefit);
954964

955965
// TODO: Consolidate FoldOp patterns
956966
// One is used by Shape Refinement, the other is a generic folder.
957967
patterns->add<FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
958-
FoldConcatenateOpPattern, FoldMulOpPattern,
959-
FoldSubtractOpPattern, FoldSqrtOpPattern>(context);
968+
FoldMulOpPattern, FoldSubtractOpPattern, FoldSqrtOpPattern>(
969+
context);
970+
patterns->add<FoldConcatenateOpPattern>(context, foldOpElementLimit);
960971
}
961972

962973
void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,

stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ namespace stablehlo {
3232
#define GEN_PASS_DEF_STABLEHLOTARGETINDEPENDENTOPTIMIZATIONPASS
3333
#include "stablehlo/transforms/optimization/Passes.h.inc"
3434

35-
// This is an upper limit on how many elements can be folded by an op folder.
36-
// This limit doesn't apply to some special cases like adding a zero,
37-
// multiplying by one, doing many operations with splats.
38-
constexpr int64_t kFoldOpEltLimit = 65536;
39-
4035
struct StablehloTargetIndependentOptimizationPass
4136
: public impl::StablehloTargetIndependentOptimizationPassBase<
4237
StablehloTargetIndependentOptimizationPass> {
@@ -45,9 +40,9 @@ struct StablehloTargetIndependentOptimizationPass
4540

4641
LogicalResult initialize(MLIRContext* context) override {
4742
RewritePatternSet patterns_(context);
48-
bool foldFloat = true;
4943
populateStablehloCanonicalizationPatterns(context, &patterns_);
5044
populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat,
45+
foldOpElementLimit,
5146
/*benefit=*/2);
5247
patterns = std::move(patterns_);
5348

@@ -58,7 +53,7 @@ struct StablehloTargetIndependentOptimizationPass
5853
GreedyRewriteConfig config;
5954
config.fold = true;
6055
config.cseConstants = true;
61-
config.maxIterations = kFoldOpEltLimit;
56+
config.maxIterations = foldOpElementLimit;
6257
config.useTopDownTraversal = false;
6358
if (failed(applyPatternsGreedily(getOperation(), patterns, config)))
6459
signalPassFailure();

0 commit comments

Comments
 (0)