Skip to content

Commit 0f7a522

Browse files
authored
Add more MHLO folders to StableHLO (#2753)
Some of these may seem unsafe (simplifying float `X*1.0` or `X*0.0` may not account for NaNs properly), but both the MHLO and XLA simplifications perform these optimizations, so they are probably safe in general.
1 parent 8993ece commit 0f7a522

File tree

5 files changed

+121
-52
lines changed

5 files changed

+121
-52
lines changed

stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,17 @@ func.func @eval_convert_f64_precision_loss() -> (tensor<1xf32>, tensor<f32>) {
269269

270270
// -----
271271

272+
// Verifies that sqrt of the constant 4.0 (f32 scalar) is folded into the
// constant 2.0, which is then returned directly.
// CHECK-LABEL: func @fold_sqrt
273+
func.func @fold_sqrt() -> (tensor<f32>) {
274+
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<2.0{{.*}}> : tensor<f32>
275+
// CHECK: return [[RESULT0]]
276+
%0 = stablehlo.constant dense<4.0> : tensor<f32>
277+
%1 = stablehlo.sqrt %0 : tensor<f32>
278+
func.return %1 : tensor<f32>
279+
}
280+
281+
// -----
282+
272283
// CHECK-LABEL: func @eval_transpose
273284
func.func @eval_transpose() -> (tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>) {
274285
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<

stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,23 @@ func.func @multiply_by_one(%arg0: tensor<i32>) -> tensor<i32> {
899899
return %0 : tensor<i32>
900900
}
901901

902+
// Verifies that multiplying by a constant float 0.0 is simplified to that
// zero constant.
// CHECK-LABEL: @multiply_by_zero_float
func.func @multiply_by_zero_float(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: %[[ZERO:.*]] = stablehlo.constant dense<0.0{{.*}}> : tensor<f32>
  %cst = stablehlo.constant dense<0.0> : tensor<f32>
  // Matching the constant alone would also succeed on the unsimplified input,
  // so additionally require that the multiply is gone and the zero is returned.
  // CHECK-NOT: stablehlo.multiply
  // CHECK: return %[[ZERO]]
  %0 = stablehlo.multiply %cst, %arg0 : tensor<f32>
  return %0 : tensor<f32>
}
909+
910+
// Verifies that multiplying by a constant float 1.0 is simplified to the
// other operand, leaving no constant behind.
// CHECK-LABEL: @multiply_by_one_float
911+
func.func @multiply_by_one_float(%arg0: tensor<f32>) -> tensor<f32> {
912+
%cst = stablehlo.constant dense<1.0> : tensor<f32>
913+
%0 = stablehlo.multiply %cst, %arg0 : tensor<f32>
914+
// CHECK-NOT: stablehlo.constant
915+
// CHECK: return %arg0 : tensor<f32>
916+
return %0 : tensor<f32>
917+
}
918+
902919
// -----
903920

904921
/////////

stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp

Lines changed: 86 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ See the License for the specific language governing permissions and
1212
limitations under the License.
1313
==============================================================================*/
1414

15+
#include <cmath>
1516
#include <cstddef>
1617
#include <cstdint>
18+
#include <functional>
1719
#include <numeric>
20+
#include <optional>
1821
#include <utility>
1922

2023
#include "llvm/ADT/APInt.h"
@@ -23,11 +26,14 @@ limitations under the License.
2326
#include "llvm/ADT/STLExtras.h"
2427
#include "llvm/ADT/SmallVector.h"
2528
#include "llvm/ADT/StringRef.h"
29+
#include "llvm/Support/Casting.h"
2630
#include "llvm/Support/Debug.h"
2731
#include "llvm/Support/ErrorHandling.h"
32+
#include "llvm/Support/LogicalResult.h"
2833
#include "mlir/Dialect/CommonFolders.h"
2934
#include "mlir/Dialect/Func/IR/FuncOps.h"
3035
#include "mlir/Dialect/Utils/IndexingUtils.h"
36+
#include "mlir/IR/BuiltinAttributeInterfaces.h"
3137
#include "mlir/IR/BuiltinAttributes.h"
3238
#include "mlir/IR/BuiltinTypeInterfaces.h"
3339
#include "mlir/IR/BuiltinTypes.h"
@@ -86,7 +92,7 @@ APSInt getAPSInt(Type type, uint64_t value) {
8692
/*isUnsigned=*/isUnsigned);
8793
}
8894

89-
LogicalResult validateResultTypeForEval(PatternRewriter& rewriter,
95+
LogicalResult validateStaticShapeResult(PatternRewriter& rewriter,
9096
Operation* op, ShapedType resultType) {
9197
if (!resultType.hasStaticShape())
9298
return rewriter.notifyMatchFailure(
@@ -212,7 +218,7 @@ template <typename OpType, typename FuncType>
212218
LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op,
213219
FuncType fn) {
214220
auto resultType = op.getType();
215-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
221+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
216222
return failure();
217223

218224
if (!isa<IntegerType>(resultType.getElementType()))
@@ -279,6 +285,28 @@ struct FoldAddOpPattern final : OpRewritePattern<mlir::stablehlo::AddOp> {
279285
}
280286
};
281287

288+
// A base class to use for patterns that may be used for integer shape math,
289+
// but also may be used for general folding of floats.
290+
template <typename OpType>
291+
struct ShapeOpRewritePattern : public OpRewritePattern<OpType> {
292+
ShapeOpRewritePattern(MLIRContext* context, PatternBenefit benefit,
293+
bool foldFloat_)
294+
: OpRewritePattern<OpType>(context, benefit), foldFloat{foldFloat_} {}
295+
296+
using OpRewritePattern<OpType>::OpRewritePattern;
297+
using OpRewritePattern<OpType>::matchAndRewrite;
298+
299+
LogicalResult validateShapeFoldDtype(PatternRewriter& rewriter, OpType op,
300+
ShapedType resultType) const {
301+
if (resultType.getElementType().isInteger()) return success();
302+
if (foldFloat && isa<FloatType>(resultType.getElementType()))
303+
return success();
304+
return rewriter.notifyMatchFailure(op, "skipping fold of shape op dtype");
305+
}
306+
307+
bool foldFloat;
308+
};
309+
282310
struct EvalAddOpShapePattern : public OpRewritePattern<AddOp> {
283311
using OpRewritePattern::OpRewritePattern;
284312
LogicalResult matchAndRewrite(AddOp op,
@@ -327,7 +355,7 @@ struct EvalBroadcastInDimOpPattern : public OpRewritePattern<BroadcastInDimOp> {
327355
LogicalResult matchAndRewrite(BroadcastInDimOp op,
328356
PatternRewriter& rewriter) const override {
329357
auto resultType = op.getType();
330-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
358+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
331359
return failure();
332360

333361
auto operandType = op.getOperand().getType();
@@ -442,7 +470,7 @@ struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
442470
LogicalResult matchAndRewrite(ConcatenateOp op,
443471
PatternRewriter& rewriter) const override {
444472
auto resultType = op.getType();
445-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
473+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
446474
return failure();
447475

448476
if (op.getDimension() != 0)
@@ -460,31 +488,23 @@ struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
460488
}
461489
};
462490

463-
struct EvalConvertOpPattern : public OpRewritePattern<ConvertOp> {
464-
using OpRewritePattern::OpRewritePattern;
465-
466-
EvalConvertOpPattern(MLIRContext* context, PatternBenefit benefit,
467-
bool foldFloat_)
468-
: OpRewritePattern<ConvertOp>(context, benefit), foldFloat{foldFloat_} {}
491+
struct EvalConvertOpPattern : public ShapeOpRewritePattern<ConvertOp> {
492+
using ShapeOpRewritePattern::ShapeOpRewritePattern;
469493

470494
LogicalResult matchAndRewrite(ConvertOp op,
471495
PatternRewriter& rewriter) const override {
472496
auto operand = op.getOperand();
473497
RankedTensorType resultType = op.getType();
474498

475-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
499+
if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
500+
failed(validateShapeFoldDtype(rewriter, op, resultType)))
476501
return failure();
477502

478503
auto operandElemType = getElementTypeOrSelf(operand.getType());
479504
auto resultElemType = getElementTypeOrSelf(resultType);
480-
if (!(operandElemType.isInteger() && resultElemType.isInteger()) &&
481-
!foldFloat)
482-
return rewriter.notifyMatchFailure(op,
483-
"lossy computations are not allowed");
484-
485-
if (!resultElemType.isIntOrFloat())
486-
return rewriter.notifyMatchFailure(
487-
op, "expected integer or float result tensor type");
505+
if (!foldFloat &&
506+
(isa<FloatType>(operandElemType) || isa<FloatType>(resultElemType)))
507+
return rewriter.notifyMatchFailure(op, "skipping fold of float convert");
488508

489509
DenseIntOrFPElementsAttr elements;
490510
if (!matchPattern(operand, m_Constant(&elements)))
@@ -493,9 +513,6 @@ struct EvalConvertOpPattern : public OpRewritePattern<ConvertOp> {
493513

494514
return evalConvert(rewriter, op, elements, resultType);
495515
}
496-
497-
private:
498-
bool foldFloat;
499516
};
500517

501518
struct EvalDivOpPattern : public OpRewritePattern<DivOp> {
@@ -513,7 +530,7 @@ struct EvalGetDimensionSizeOpPattern
513530
LogicalResult matchAndRewrite(GetDimensionSizeOp op,
514531
PatternRewriter& rewriter) const override {
515532
auto resultType = op.getType();
516-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
533+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
517534
return failure();
518535

519536
auto operandType = op.getOperand().getType();
@@ -552,23 +569,11 @@ struct FoldMulOpPattern final : OpRewritePattern<mlir::stablehlo::MulOp> {
552569

553570
LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op,
554571
PatternRewriter& rewriter) const override {
555-
auto elemType = op.getType().getElementType();
556-
Value lhs = op.getLhs();
557-
Value rhs = op.getRhs();
558-
559572
TypedAttr lhsAttr;
560-
matchPattern(lhs, m_Constant(&lhsAttr));
573+
matchPattern(op.getLhs(), m_Constant(&lhsAttr));
561574

562575
TypedAttr rhsAttr;
563-
matchPattern(rhs, m_Constant(&rhsAttr));
564-
565-
// The canonical form has the constant operand as the RHS.
566-
if (isa<IntegerType>(elemType) && lhsAttr && !rhsAttr) {
567-
rewriter.modifyOpInPlace(op, [op, lhs, rhs] {
568-
op->setOperands(ValueRange{rhs, lhs});
569-
});
570-
return success();
571-
}
576+
matchPattern(op.getRhs(), m_Constant(&rhsAttr));
572577

573578
if (TypedAttr res;
574579
lhsAttr && rhsAttr &&
@@ -613,16 +618,18 @@ struct EvalRemOpPattern : public OpRewritePattern<RemOp> {
613618
}
614619
};
615620

616-
struct EvalReshapeOpPattern : public OpRewritePattern<ReshapeOp> {
617-
using OpRewritePattern::OpRewritePattern;
621+
struct EvalReshapeOpPattern : public ShapeOpRewritePattern<ReshapeOp> {
622+
using ShapeOpRewritePattern::ShapeOpRewritePattern;
623+
618624
LogicalResult matchAndRewrite(ReshapeOp op,
619625
PatternRewriter& rewriter) const override {
620626
auto resultType = op.getType();
621-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
627+
if (failed(validateStaticShapeResult(rewriter, op, resultType)) ||
628+
failed(validateShapeFoldDtype(rewriter, op, resultType)))
622629
return failure();
623630

624631
// Pattern: reshape(cst, shape) -> cst
625-
DenseIntElementsAttr attr;
632+
DenseIntOrFPElementsAttr attr;
626633
if (!matchPattern(op.getOperand(), m_Constant(&attr)))
627634
return rewriter.notifyMatchFailure(op, "expected constant operand");
628635
rewriter.replaceOpWithNewOp<ConstantOp>(op, attr.reshape(resultType));
@@ -635,7 +642,7 @@ struct EvalSelectOpPattern : public OpRewritePattern<SelectOp> {
635642
LogicalResult matchAndRewrite(SelectOp op,
636643
PatternRewriter& rewriter) const override {
637644
auto resultType = op.getType();
638-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
645+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
639646
return failure();
640647

641648
SmallVector<APSInt> pred, onTrue, onFalse;
@@ -717,7 +724,7 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
717724
LogicalResult matchAndRewrite(SliceOp op,
718725
PatternRewriter& rewriter) const override {
719726
auto resultType = op.getType();
720-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
727+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
721728
return failure();
722729

723730
auto operand = op.getOperand();
@@ -778,6 +785,37 @@ struct EvalSubtractOpPattern : public OpRewritePattern<SubtractOp> {
778785
}
779786
};
780787

788+
// Folds `stablehlo.sqrt` applied to a constant 32-bit or 64-bit float operand
// into a `stablehlo.constant` holding the element-wise square root.
struct FoldSqrtOpPattern : public OpRewritePattern<mlir::stablehlo::SqrtOp> {
  using OpRewritePattern<mlir::stablehlo::SqrtOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mlir::stablehlo::SqrtOp op,
                                PatternRewriter& rewriter) const final {
    TypedAttr operandAttr;
    if (!matchPattern(op.getOperand(), m_Constant(&operandAttr)) ||
        !operandAttr)
      return rewriter.notifyMatchFailure(op, "operand not constant");

    // A null result means the fold was not possible (for instance when
    // foldSqrt declines the element's float format).
    auto folded = constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void>(
        operandAttr, foldSqrt);
    if (!folded) return rewriter.notifyMatchFailure(op, "unable to fold sqrt");

    rewriter.replaceOpWithNewOp<stablehlo::ConstantOp>(
        op, op.getType(), llvm::cast<ElementsAttr>(folded));
    return success();
  }

  // Computes the square root of `a` with the host math library for 64-bit
  // and 32-bit float values; returns std::nullopt for any other width.
  static std::optional<APFloat> foldSqrt(const APFloat& a) {
    const auto bitWidth = APFloat::getSizeInBits(a.getSemantics());
    if (bitWidth == 64) return APFloat(std::sqrt(a.convertToDouble()));

    // std::sqrt picks the float overload here, so the result stays a float;
    // unlike the unqualified C `sqrtf`, it is guaranteed by <cmath>.
    if (bitWidth == 32) return APFloat(std::sqrt(a.convertToFloat()));
    return std::nullopt;
  }
};
818+
781819
struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
782820
using OpRewritePattern::OpRewritePattern;
783821
LogicalResult matchAndRewrite(IotaOp op,
@@ -860,7 +898,7 @@ struct EvalTransposeOpPattern : public OpRewritePattern<TransposeOp> {
860898
LogicalResult matchAndRewrite(TransposeOp op,
861899
PatternRewriter& rewriter) const override {
862900
auto resultType = op.getType();
863-
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
901+
if (failed(validateStaticShapeResult(rewriter, op, resultType)))
864902
return failure();
865903

866904
ElementsAttr els;
@@ -916,10 +954,9 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
916954

917955
// TODO: Consolidate FoldOp patterns
918956
// One is used by Shape Refinement, the other is a generic folder.
919-
patterns
920-
->add<FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
921-
FoldConcatenateOpPattern, FoldMulOpPattern, FoldSubtractOpPattern>(
922-
context);
957+
patterns->add<FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
958+
FoldConcatenateOpPattern, FoldMulOpPattern,
959+
FoldSubtractOpPattern, FoldSqrtOpPattern>(context);
923960
}
924961

925962
void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,
@@ -939,7 +976,7 @@ void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,
939976
patterns->add<EvalMulOpPattern>(context, benefit);
940977
patterns->add<EvalOrOpPattern>(context, benefit);
941978
patterns->add<EvalRemOpPattern>(context, benefit);
942-
patterns->add<EvalReshapeOpPattern>(context, benefit);
979+
patterns->add<EvalReshapeOpPattern>(context, benefit, foldFloat);
943980
patterns->add<EvalSelectOpPattern>(context, benefit);
944981
patterns->add<EvalSignOpPattern>(context, benefit);
945982
patterns->add<EvalSliceOpPattern>(context, benefit);

stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def EmptyI64Array : AttrConstraint<
6868
CPred<"cast<DenseI64ArrayAttr>($_self).empty()">,
6969
"is empty i64 array">;
7070

71+
// Attribute constraint that holds when the attribute matches either m_One()
// or m_OneFloat().
def AnyOne : AttrConstraint<
    CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_One(), m_OneFloat()))">,
    "is integer or float one">;
74+
7175
// Attribute constraint that holds when the attribute matches m_One().
def IntOne : AttrConstraint<
7276
CPred<"::mlir::matchPattern($_self, m_One())">,
7377
"is integer one">;
@@ -303,11 +307,11 @@ def : CanonicalizeConstantToRhs<StableHLO_MulOp>;
303307

304308
// Pattern: multiply(X, 0) -> 0
305309
// Multiplication by 0. For floats this fold does not account for NaNs, but it is applied anyway, matching the MHLO and XLA simplifications.
306-
def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)),
310+
def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp:$zero AnyZero:$value)),
307311
(replaceWithValue $zero)>;
308312

309313
// Pattern: multiply(X, 1) -> X (integer or float one)
310-
def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)),
314+
def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp AnyOne:$value)),
311315
(replaceWithValue $lhs)>;
312316

313317
////////

stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ struct StablehloTargetIndependentOptimizationPass
4545

4646
LogicalResult initialize(MLIRContext* context) override {
4747
RewritePatternSet patterns_(context);
48-
bool foldFloat = false;
48+
bool foldFloat = true;
4949
populateStablehloCanonicalizationPatterns(context, &patterns_);
5050
populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat,
5151
/*benefit=*/2);

0 commit comments

Comments
 (0)