@@ -12,9 +12,12 @@ See the License for the specific language governing permissions and
1212limitations under the License.
1313==============================================================================*/
1414
15+ #include < cmath>
1516#include < cstddef>
1617#include < cstdint>
18+ #include < functional>
1719#include < numeric>
20+ #include < optional>
1821#include < utility>
1922
2023#include " llvm/ADT/APInt.h"
@@ -23,11 +26,14 @@ limitations under the License.
2326#include " llvm/ADT/STLExtras.h"
2427#include " llvm/ADT/SmallVector.h"
2528#include " llvm/ADT/StringRef.h"
29+ #include " llvm/Support/Casting.h"
2630#include " llvm/Support/Debug.h"
2731#include " llvm/Support/ErrorHandling.h"
32+ #include " llvm/Support/LogicalResult.h"
2833#include " mlir/Dialect/CommonFolders.h"
2934#include " mlir/Dialect/Func/IR/FuncOps.h"
3035#include " mlir/Dialect/Utils/IndexingUtils.h"
36+ #include " mlir/IR/BuiltinAttributeInterfaces.h"
3137#include " mlir/IR/BuiltinAttributes.h"
3238#include " mlir/IR/BuiltinTypeInterfaces.h"
3339#include " mlir/IR/BuiltinTypes.h"
@@ -86,7 +92,7 @@ APSInt getAPSInt(Type type, uint64_t value) {
8692 /* isUnsigned=*/ isUnsigned);
8793}
8894
89- LogicalResult validateResultTypeForEval (PatternRewriter& rewriter,
95+ LogicalResult validateStaticShapeResult (PatternRewriter& rewriter,
9096 Operation* op, ShapedType resultType) {
9197 if (!resultType.hasStaticShape ())
9298 return rewriter.notifyMatchFailure (
@@ -212,7 +218,7 @@ template <typename OpType, typename FuncType>
212218LogicalResult evalElementwise (PatternRewriter& rewriter, OpType op,
213219 FuncType fn) {
214220 auto resultType = op.getType ();
215- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
221+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
216222 return failure ();
217223
218224 if (!isa<IntegerType>(resultType.getElementType ()))
@@ -279,6 +285,28 @@ struct FoldAddOpPattern final : OpRewritePattern<mlir::stablehlo::AddOp> {
279285 }
280286};
281287
288+ // A base class to use for patterns that may be used for integer shape math,
289+ // but also may be used for general folding of floats.
290+ template <typename OpType>
291+ struct ShapeOpRewritePattern : public OpRewritePattern <OpType> {
292+ ShapeOpRewritePattern (MLIRContext* context, PatternBenefit benefit,
293+ bool foldFloat_)
294+ : OpRewritePattern<OpType>(context, benefit), foldFloat{foldFloat_} {}
295+
296+ using OpRewritePattern<OpType>::OpRewritePattern;
297+ using OpRewritePattern<OpType>::matchAndRewrite;
298+
299+ LogicalResult validateShapeFoldDtype (PatternRewriter& rewriter, OpType op,
300+ ShapedType resultType) const {
301+ if (resultType.getElementType ().isInteger ()) return success ();
302+ if (foldFloat && isa<FloatType>(resultType.getElementType ()))
303+ return success ();
304+ return rewriter.notifyMatchFailure (op, " skipping fold of shape op dtype" );
305+ }
306+
307+ bool foldFloat;
308+ };
309+
282310struct EvalAddOpShapePattern : public OpRewritePattern <AddOp> {
283311 using OpRewritePattern::OpRewritePattern;
284312 LogicalResult matchAndRewrite (AddOp op,
@@ -327,7 +355,7 @@ struct EvalBroadcastInDimOpPattern : public OpRewritePattern<BroadcastInDimOp> {
327355 LogicalResult matchAndRewrite (BroadcastInDimOp op,
328356 PatternRewriter& rewriter) const override {
329357 auto resultType = op.getType ();
330- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
358+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
331359 return failure ();
332360
333361 auto operandType = op.getOperand ().getType ();
@@ -442,7 +470,7 @@ struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
442470 LogicalResult matchAndRewrite (ConcatenateOp op,
443471 PatternRewriter& rewriter) const override {
444472 auto resultType = op.getType ();
445- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
473+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
446474 return failure ();
447475
448476 if (op.getDimension () != 0 )
@@ -460,31 +488,23 @@ struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
460488 }
461489};
462490
463- struct EvalConvertOpPattern : public OpRewritePattern <ConvertOp> {
464- using OpRewritePattern::OpRewritePattern;
465-
466- EvalConvertOpPattern (MLIRContext* context, PatternBenefit benefit,
467- bool foldFloat_)
468- : OpRewritePattern<ConvertOp>(context, benefit), foldFloat{foldFloat_} {}
491+ struct EvalConvertOpPattern : public ShapeOpRewritePattern <ConvertOp> {
492+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
469493
470494 LogicalResult matchAndRewrite (ConvertOp op,
471495 PatternRewriter& rewriter) const override {
472496 auto operand = op.getOperand ();
473497 RankedTensorType resultType = op.getType ();
474498
475- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
499+ if (failed (validateStaticShapeResult (rewriter, op, resultType)) ||
500+ failed (validateShapeFoldDtype (rewriter, op, resultType)))
476501 return failure ();
477502
478503 auto operandElemType = getElementTypeOrSelf (operand.getType ());
479504 auto resultElemType = getElementTypeOrSelf (resultType);
480- if (!(operandElemType.isInteger () && resultElemType.isInteger ()) &&
481- !foldFloat)
482- return rewriter.notifyMatchFailure (op,
483- " lossy computations are not allowed" );
484-
485- if (!resultElemType.isIntOrFloat ())
486- return rewriter.notifyMatchFailure (
487- op, " expected integer or float result tensor type" );
505+ if (!foldFloat &&
506+ (isa<FloatType>(operandElemType) || isa<FloatType>(resultElemType)))
507+ return rewriter.notifyMatchFailure (op, " skipping fold of float convert" );
488508
489509 DenseIntOrFPElementsAttr elements;
490510 if (!matchPattern (operand, m_Constant (&elements)))
@@ -493,9 +513,6 @@ struct EvalConvertOpPattern : public OpRewritePattern<ConvertOp> {
493513
494514 return evalConvert (rewriter, op, elements, resultType);
495515 }
496-
497- private:
498- bool foldFloat;
499516};
500517
501518struct EvalDivOpPattern : public OpRewritePattern <DivOp> {
@@ -513,7 +530,7 @@ struct EvalGetDimensionSizeOpPattern
513530 LogicalResult matchAndRewrite (GetDimensionSizeOp op,
514531 PatternRewriter& rewriter) const override {
515532 auto resultType = op.getType ();
516- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
533+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
517534 return failure ();
518535
519536 auto operandType = op.getOperand ().getType ();
@@ -552,23 +569,11 @@ struct FoldMulOpPattern final : OpRewritePattern<mlir::stablehlo::MulOp> {
552569
553570 LogicalResult matchAndRewrite (mlir::stablehlo::MulOp op,
554571 PatternRewriter& rewriter) const override {
555- auto elemType = op.getType ().getElementType ();
556- Value lhs = op.getLhs ();
557- Value rhs = op.getRhs ();
558-
559572 TypedAttr lhsAttr;
560- matchPattern (lhs , m_Constant (&lhsAttr));
573+ matchPattern (op. getLhs () , m_Constant (&lhsAttr));
561574
562575 TypedAttr rhsAttr;
563- matchPattern (rhs, m_Constant (&rhsAttr));
564-
565- // The canonical form has the constant operand as the RHS.
566- if (isa<IntegerType>(elemType) && lhsAttr && !rhsAttr) {
567- rewriter.modifyOpInPlace (op, [op, lhs, rhs] {
568- op->setOperands (ValueRange{rhs, lhs});
569- });
570- return success ();
571- }
576+ matchPattern (op.getRhs (), m_Constant (&rhsAttr));
572577
573578 if (TypedAttr res;
574579 lhsAttr && rhsAttr &&
@@ -613,16 +618,18 @@ struct EvalRemOpPattern : public OpRewritePattern<RemOp> {
613618 }
614619};
615620
616- struct EvalReshapeOpPattern : public OpRewritePattern <ReshapeOp> {
617- using OpRewritePattern::OpRewritePattern;
621+ struct EvalReshapeOpPattern : public ShapeOpRewritePattern <ReshapeOp> {
622+ using ShapeOpRewritePattern::ShapeOpRewritePattern;
623+
618624 LogicalResult matchAndRewrite (ReshapeOp op,
619625 PatternRewriter& rewriter) const override {
620626 auto resultType = op.getType ();
621- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
627+ if (failed (validateStaticShapeResult (rewriter, op, resultType)) ||
628+ failed (validateShapeFoldDtype (rewriter, op, resultType)))
622629 return failure ();
623630
624631 // Pattern: reshape(cst, shape) -> cst
625- DenseIntElementsAttr attr;
632+ DenseIntOrFPElementsAttr attr;
626633 if (!matchPattern (op.getOperand (), m_Constant (&attr)))
627634 return rewriter.notifyMatchFailure (op, " expected constant operand" );
628635 rewriter.replaceOpWithNewOp <ConstantOp>(op, attr.reshape (resultType));
@@ -635,7 +642,7 @@ struct EvalSelectOpPattern : public OpRewritePattern<SelectOp> {
635642 LogicalResult matchAndRewrite (SelectOp op,
636643 PatternRewriter& rewriter) const override {
637644 auto resultType = op.getType ();
638- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
645+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
639646 return failure ();
640647
641648 SmallVector<APSInt> pred, onTrue, onFalse;
@@ -717,7 +724,7 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
717724 LogicalResult matchAndRewrite (SliceOp op,
718725 PatternRewriter& rewriter) const override {
719726 auto resultType = op.getType ();
720- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
727+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
721728 return failure ();
722729
723730 auto operand = op.getOperand ();
@@ -778,6 +785,37 @@ struct EvalSubtractOpPattern : public OpRewritePattern<SubtractOp> {
778785 }
779786};
780787
788+ struct FoldSqrtOpPattern : public OpRewritePattern <mlir::stablehlo::SqrtOp> {
789+ using OpRewritePattern<mlir::stablehlo::SqrtOp>::OpRewritePattern;
790+
791+ LogicalResult matchAndRewrite (mlir::stablehlo::SqrtOp op,
792+ PatternRewriter& rewriter) const final {
793+ TypedAttr lhsAttr;
794+ matchPattern (op.getOperand (), m_Constant (&lhsAttr));
795+
796+ if (!lhsAttr)
797+ return rewriter.notifyMatchFailure (op, " operand not constant" );
798+
799+ if (auto res = constFoldUnaryOp<FloatAttr, FloatAttr::ValueType, void >(
800+ lhsAttr, foldSqrt)) {
801+ rewriter.replaceOpWithNewOp <stablehlo::ConstantOp>(
802+ op, op.getType (), llvm::cast<ElementsAttr>(res));
803+ return success ();
804+ }
805+
806+ return rewriter.notifyMatchFailure (op, " unable to fold sqrt" );
807+ }
808+
809+ static std::optional<APFloat> foldSqrt (const APFloat& a) {
810+ if (a.getSizeInBits (a.getSemantics ()) == 64 )
811+ return APFloat (std::sqrt (a.convertToDouble ()));
812+
813+ if (a.getSizeInBits (a.getSemantics ()) == 32 )
814+ return APFloat (sqrtf (a.convertToFloat ()));
815+ return {};
816+ }
817+ };
818+
781819struct EvalIotaOpPattern : public OpRewritePattern <IotaOp> {
782820 using OpRewritePattern::OpRewritePattern;
783821 LogicalResult matchAndRewrite (IotaOp op,
@@ -860,7 +898,7 @@ struct EvalTransposeOpPattern : public OpRewritePattern<TransposeOp> {
860898 LogicalResult matchAndRewrite (TransposeOp op,
861899 PatternRewriter& rewriter) const override {
862900 auto resultType = op.getType ();
863- if (failed (validateResultTypeForEval (rewriter, op, resultType)))
901+ if (failed (validateStaticShapeResult (rewriter, op, resultType)))
864902 return failure ();
865903
866904 ElementsAttr els;
@@ -916,10 +954,9 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
916954
917955 // TODO: Consolidate FoldOp patterns
918956 // One is used by Shape Refinement, the other is a generic folder.
919- patterns
920- ->add <FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
921- FoldConcatenateOpPattern, FoldMulOpPattern, FoldSubtractOpPattern>(
922- context);
957+ patterns->add <FoldAddOpPattern, FoldBroadcastInDimSplatPattern,
958+ FoldConcatenateOpPattern, FoldMulOpPattern,
959+ FoldSubtractOpPattern, FoldSqrtOpPattern>(context);
923960}
924961
925962void populateStablehloShapeFolderPatterns (RewritePatternSet* patterns,
@@ -939,7 +976,7 @@ void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,
939976 patterns->add <EvalMulOpPattern>(context, benefit);
940977 patterns->add <EvalOrOpPattern>(context, benefit);
941978 patterns->add <EvalRemOpPattern>(context, benefit);
942- patterns->add <EvalReshapeOpPattern>(context, benefit);
979+ patterns->add <EvalReshapeOpPattern>(context, benefit, foldFloat );
943980 patterns->add <EvalSelectOpPattern>(context, benefit);
944981 patterns->add <EvalSignOpPattern>(context, benefit);
945982 patterns->add <EvalSliceOpPattern>(context, benefit);
0 commit comments