@@ -2828,29 +2828,29 @@ struct CustomOpFuseMatMulPattern : public OpRewritePattern<ONNXCustomOp> {
28282828 }
28292829};
28302830
2831- namespace {
2831+ static constexpr StringLiteral MicrosoftDomainName (" com.microsoft" );
2832+ static constexpr StringLiteral DefaultONNXDomainName (" " );
28322833
2833- [[nodiscard]] bool isCustomMicrosoftOp (
2834- ONNXCustomOp customOp, StringRef expectedName) {
2834+ [[nodiscard]] bool isCustomOpWithNameAndDialect (
2835+ ONNXCustomOp customOp, StringRef expectedName, StringRef expectedDialect ) {
28352836 if (!customOp.getFunctionName ().equals_insensitive (expectedName)) {
28362837 return false ;
28372838 }
28382839
28392840 const auto domAttr = customOp->getAttrOfType <StringAttr>(" domain_name" );
2840- return domAttr && domAttr.getValue ().equals_insensitive (" com.microsoft " );
2841+ return domAttr && domAttr.getValue ().equals_insensitive (expectedDialect );
28412842}
28422843
2843- } // namespace
2844-
2845- struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern <ONNXCustomOp> {
2846- CustomOpMicrosoftToOnnxOps (MLIRContext *context,
2847- std::string operationNameToRewrite, PatternBenefit benefit = 1 )
2848- : OpRewritePattern<ONNXCustomOp>(context, benefit),
2849- operationNameToRewrite (std::move(operationNameToRewrite)) {}
2844+ struct CustomOpToOnnxOps : public OpRewritePattern <ONNXCustomOp> {
2845+ CustomOpToOnnxOps (MLIRContext *context, StringRef dialect,
2846+ StringRef operationNameToRewrite, PatternBenefit benefit = 1 )
2847+ : OpRewritePattern<ONNXCustomOp>(context, benefit), dialect(dialect),
2848+ operationNameToRewrite (operationNameToRewrite) {}
28502849
28512850 LogicalResult matchAndRewrite (
28522851 ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
2853- if (!isCustomMicrosoftOp (customOp, operationNameToRewrite)) {
2852+ if (!isCustomOpWithNameAndDialect (
2853+ customOp, operationNameToRewrite, dialect)) {
28542854 return failure ();
28552855 }
28562856
@@ -2908,12 +2908,13 @@ struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
29082908 })};
29092909 }
29102910
2911+ const std::string dialect;
29112912 const std::string operationNameToRewrite;
29122913};
29132914
2914- struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
2915+ struct MicrosoftBiasGelu : public CustomOpToOnnxOps {
29152916 MicrosoftBiasGelu (MLIRContext *context, PatternBenefit benefit = 1 )
2916- : CustomOpMicrosoftToOnnxOps (context, " BiasGelu" , benefit) {}
2917+ : CustomOpToOnnxOps (context, MicrosoftDomainName , " BiasGelu" , benefit) {}
29172918
29182919 LogicalResult matchAndRewriteImpl (
29192920 ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -2936,9 +2937,9 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
29362937 }
29372938};
29382939
2939- struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
2940+ struct MicrosoftFusedConv : public CustomOpToOnnxOps {
29402941 MicrosoftFusedConv (MLIRContext *context, PatternBenefit benefit = 1 )
2941- : CustomOpMicrosoftToOnnxOps (context, " FusedConv" , benefit) {}
2942+ : CustomOpToOnnxOps (context, MicrosoftDomainName , " FusedConv" , benefit) {}
29422943
29432944 LogicalResult matchAndRewriteImpl (
29442945 ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3033,9 +3034,80 @@ struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
30333034 }
30343035};
30353036
3036- struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
3037+ // / Note: This is an operation in onnxruntime, which is in the ONNX instead of
3038+ // / Microsoft domain for historic reasons.
3039+ struct SimplifiedLayerNorm : public CustomOpToOnnxOps {
3040+ SimplifiedLayerNorm (MLIRContext *ctx, PatternBenefit b = 1 )
3041+ : CustomOpToOnnxOps(
3042+ ctx, DefaultONNXDomainName, " SimplifiedLayerNormalization" , b) {}
3043+
3044+ LogicalResult matchAndRewriteImpl (
3045+ ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
3046+ using namespace onnx_mlir ;
3047+ Location loc = customOp.getLoc ();
3048+ const int64_t numIn = customOp.getNumOperands ();
3049+ assert ((numIn >= 1 && numIn <= 3 ) && " expects 1..3 inputs" );
3050+ const int64_t numOut = customOp.getNumResults ();
3051+ assert ((numOut >= 1 && numOut <= 3 ) && " expects 1..3 outputs" );
3052+ // The onnxruntime version of RMSNorm/SimplifiedLayerNorm supports 1-3
3053+ // outputs, (output, mean, inv_std_var) The version in onnx and onnx-mlir
3054+ // only support output (and inv_std_var in case of onnx-mlir)
3055+ if (numOut > 1 ) {
3056+ if (!isa<NoneType>(customOp.getResultTypes ()[1 ])) {
3057+ return rewriter.notifyMatchFailure (
3058+ customOp, " Use of mean not supported yet" );
3059+ }
3060+ }
3061+
3062+ MultiDialectBuilder<OnnxBuilder> create (rewriter, customOp->getLoc ());
3063+
3064+ Value none = create.onnx .none ();
3065+
3066+ Value input = customOp.getOperand (0 );
3067+ Value scale = customOp.getOperand (1 );
3068+ Value bias = none; // layer-norm bias
3069+
3070+ if (numIn >= 3 )
3071+ bias = customOp.getOperand (2 );
3072+
3073+ auto epsAttr = customOp->getAttrOfType <FloatAttr>(" epsilon" );
3074+ assert (epsAttr && " Expected Epsilon" );
3075+ auto axisAttr = customOp->getAttrOfType <IntegerAttr>(" axis" );
3076+ assert (axisAttr && " Expected Axis" );
3077+ auto stashTypeAttr = customOp->getAttrOfType <IntegerAttr>(" stash_type" );
3078+ assert (stashTypeAttr && " Expected Stash Type" );
3079+
3080+ SmallVector<Type, 2 > resultTypes;
3081+ resultTypes.push_back (customOp->getResultTypes ()[0 ]);
3082+ resultTypes.push_back (
3083+ numOut > 2 ? customOp->getResultTypes ()[2 ] : rewriter.getNoneType ());
3084+
3085+ auto rms = rewriter.create <ONNXRMSLayerNormalizationOp>(
3086+ loc, resultTypes, input, scale, bias, axisAttr, epsAttr, stashTypeAttr);
3087+
3088+ SmallVector<Value, 3 > replace;
3089+ replace.push_back (rms.getResult (0 ));
3090+ if (numOut > 1 )
3091+ replace.push_back (none);
3092+ if (numOut > 2 )
3093+ replace.push_back (rms.getResult (1 ));
3094+
3095+ SmallVector<Value, 4 > toCheck (replace.begin (), replace.end ());
3096+ toCheck.push_back (none);
3097+
3098+ if (failed (verifyOpsErasingOnError (toCheck, rewriter))) {
3099+ return rewriter.notifyMatchFailure (customOp, " Failed verification" );
3100+ }
3101+
3102+ rewriter.replaceOp (customOp, replace);
3103+ return success ();
3104+ }
3105+ };
3106+
3107+ struct MicrosoftSkipLayerNorm : public CustomOpToOnnxOps {
30373108 MicrosoftSkipLayerNorm (MLIRContext *ctx, PatternBenefit b = 1 )
3038- : CustomOpMicrosoftToOnnxOps(ctx, " SkipLayerNormalization" , b) {}
3109+ : CustomOpToOnnxOps(
3110+ ctx, MicrosoftDomainName, " SkipLayerNormalization" , b) {}
30393111
30403112 LogicalResult matchAndRewriteImpl (
30413113 ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3082,21 +3154,21 @@ struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
30823154
30833155 const auto si64Type = rewriter.getIntegerType (64 , /* signed*/ true );
30843156
3085- auto ln = rewriter.create <ONNXLayerNormalizationOp>(loc, resultTypes, sumIS ,
3086- gamma, beta, /* axis*/
3157+ auto rms = rewriter.create <ONNXLayerNormalizationOp>(loc, resultTypes,
3158+ sumIS, gamma, beta, /* axis*/
30873159 rewriter.getIntegerAttr (si64Type, -1 ), epsAttr,
30883160 /* stashType*/ rewriter.getIntegerAttr (si64Type, 1 ));
30893161
30903162 SmallVector<Value, 4 > replace;
3091- replace.push_back (ln .getResult (0 ));
3163+ replace.push_back (rms .getResult (0 ));
30923164 if (numOut >= 2 )
3093- replace.push_back (ln .getResult (1 )); // mean
3165+ replace.push_back (rms .getResult (1 )); // mean
30943166 if (numOut >= 3 )
3095- replace.push_back (ln .getResult (2 )); // inv_std_var
3167+ replace.push_back (rms .getResult (2 )); // inv_std_var
30963168 if (numOut == 4 )
30973169 replace.push_back (sumIS); // input_skip_bias_sum
30983170
3099- SmallVector<Value, 6 > toCheck (replace.begin (), replace.end ());
3171+ SmallVector<Value, 7 > toCheck (replace.begin (), replace.end ());
31003172 toCheck.push_back (none);
31013173 toCheck.push_back (skipAdd);
31023174 toCheck.push_back (sumIS);
@@ -3111,8 +3183,11 @@ struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
31113183};
31123184
31133185template <typename OpToCreate>
3114- struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
3115- using CustomOpMicrosoftToOnnxOps::CustomOpMicrosoftToOnnxOps;
3186+ struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpToOnnxOps {
3187+ CustomOpMicrosoftToSingleOnnxOp (MLIRContext *context,
3188+ StringRef operationNameToRewrite, PatternBenefit benefit = 1 )
3189+ : CustomOpToOnnxOps(
3190+ context, MicrosoftDomainName, operationNameToRewrite, benefit) {}
31163191
31173192 LogicalResult matchAndRewriteImpl (
31183193 ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3518,6 +3593,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
35183593 patterns.insert <MicrosoftBiasGelu>(context);
35193594 patterns.insert <MicrosoftFusedConv>(context);
35203595 patterns.insert <MicrosoftSkipLayerNorm>(context);
3596+ patterns.insert <SimplifiedLayerNorm>(context);
35213597 patterns.insert <DecomposeSlicePadPattern>(context);
35223598 patterns.insert <DecomposeScatterNDPattern>(context);
35233599 patterns.insert <SoftmaxCrossEntropyPattern>(context);
0 commit comments