Commit bccacea

Add decomposition for SimplifiedLayerNorm
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 24f4192 commit bccacea
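
For orientation, a minimal before/after sketch of the rewrite this commit adds, mirroring the basic lit test added below (tensor shapes and attribute values are illustrative and taken from the new tests; this sketch is not part of the committed code):

// Before: onnxruntime's SimplifiedLayerNormalization arrives as a custom op in the default ONNX domain.
%r = "onnx.Custom"(%input, %scale) {axis = -1 : si64, domain_name = "", epsilon = 1.000000e-05 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>

// After decomposition: the custom op is replaced by onnx-mlir's RMS layer normalization.
// The optional bias operand is none here; 1e-5 prints as 9.99999974E-6 at f32 precision.
%none = "onnx.NoValue"() {value} : () -> none
%Y, %InvStdDev = "onnx.RMSLayerNormalization"(%input, %scale, %none) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)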

2 files changed: +170 −26 lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 102 additions & 26 deletions
@@ -2828,29 +2828,29 @@ struct CustomOpFuseMatMulPattern : public OpRewritePattern<ONNXCustomOp> {
   }
 };
 
-namespace {
+static constexpr StringLiteral MicrosoftDomainName("com.microsoft");
+static constexpr StringLiteral DefaultONNXDomainName("");
 
-[[nodiscard]] bool isCustomMicrosoftOp(
-    ONNXCustomOp customOp, StringRef expectedName) {
+[[nodiscard]] bool isCustomOpWithNameAndDialect(
+    ONNXCustomOp customOp, StringRef expectedName, StringRef expectedDialect) {
   if (!customOp.getFunctionName().equals_insensitive(expectedName)) {
     return false;
   }
 
   const auto domAttr = customOp->getAttrOfType<StringAttr>("domain_name");
-  return domAttr && domAttr.getValue().equals_insensitive("com.microsoft");
+  return domAttr && domAttr.getValue().equals_insensitive(expectedDialect);
 }
 
-} // namespace
-
-struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
-  CustomOpMicrosoftToOnnxOps(MLIRContext *context,
-      std::string operationNameToRewrite, PatternBenefit benefit = 1)
-      : OpRewritePattern<ONNXCustomOp>(context, benefit),
-        operationNameToRewrite(std::move(operationNameToRewrite)) {}
+struct CustomOpToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
+  CustomOpToOnnxOps(MLIRContext *context, StringRef dialect,
+      StringRef operationNameToRewrite, PatternBenefit benefit = 1)
+      : OpRewritePattern<ONNXCustomOp>(context, benefit), dialect(dialect),
+        operationNameToRewrite(operationNameToRewrite) {}
 
   LogicalResult matchAndRewrite(
       ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
-    if (!isCustomMicrosoftOp(customOp, operationNameToRewrite)) {
+    if (!isCustomOpWithNameAndDialect(
+            customOp, operationNameToRewrite, dialect)) {
       return failure();
     }
 
@@ -2908,12 +2908,13 @@ struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
         })};
   }
 
+  const std::string dialect;
   const std::string operationNameToRewrite;
 };
 
-struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
+struct MicrosoftBiasGelu : public CustomOpToOnnxOps {
   MicrosoftBiasGelu(MLIRContext *context, PatternBenefit benefit = 1)
-      : CustomOpMicrosoftToOnnxOps(context, "BiasGelu", benefit) {}
+      : CustomOpToOnnxOps(context, MicrosoftDomainName, "BiasGelu", benefit) {}
 
   LogicalResult matchAndRewriteImpl(
       ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -2936,9 +2937,9 @@ struct MicrosoftBiasGelu : public CustomOpMicrosoftToOnnxOps {
   }
 };
 
-struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
+struct MicrosoftFusedConv : public CustomOpToOnnxOps {
   MicrosoftFusedConv(MLIRContext *context, PatternBenefit benefit = 1)
-      : CustomOpMicrosoftToOnnxOps(context, "FusedConv", benefit) {}
+      : CustomOpToOnnxOps(context, MicrosoftDomainName, "FusedConv", benefit) {}
 
   LogicalResult matchAndRewriteImpl(
       ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3033,9 +3034,80 @@ struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
   }
 };
 
-struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
+/// Note: This is an operation in onnxruntime, which is in the ONNX instead of
+/// the Microsoft domain for historic reasons.
+struct SimplifiedLayerNorm : public CustomOpToOnnxOps {
+  SimplifiedLayerNorm(MLIRContext *ctx, PatternBenefit b = 1)
+      : CustomOpToOnnxOps(
+            ctx, DefaultONNXDomainName, "SimplifiedLayerNormalization", b) {}
+
+  LogicalResult matchAndRewriteImpl(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    Location loc = customOp.getLoc();
+    const int64_t numIn = customOp.getNumOperands();
+    assert((numIn >= 1 && numIn <= 3) && "expects 1..3 inputs");
+    const int64_t numOut = customOp.getNumResults();
+    assert((numOut >= 1 && numOut <= 3) && "expects 1..3 outputs");
+    // The onnxruntime version of RMSNorm/SimplifiedLayerNorm supports 1-3
+    // outputs (output, mean, inv_std_var). The versions in onnx and onnx-mlir
+    // only support output (and inv_std_var in the case of onnx-mlir).
+    if (numOut > 1) {
+      if (!isa<NoneType>(customOp.getResultTypes()[1])) {
+        return rewriter.notifyMatchFailure(
+            customOp, "Use of mean not supported yet");
+      }
+    }
+
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
+
+    Value none = create.onnx.none();
+
+    Value input = customOp.getOperand(0);
+    Value scale = customOp.getOperand(1);
+    Value bias = none; // layer-norm bias
+
+    if (numIn >= 3)
+      bias = customOp.getOperand(2);
+
+    auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
+    assert(epsAttr && "Expected Epsilon");
+    auto axisAttr = customOp->getAttrOfType<IntegerAttr>("axis");
+    assert(axisAttr && "Expected Axis");
+    auto stashTypeAttr = customOp->getAttrOfType<IntegerAttr>("stash_type");
+    assert(stashTypeAttr && "Expected Stash Type");
+
+    SmallVector<Type, 2> resultTypes;
+    resultTypes.push_back(customOp->getResultTypes()[0]);
+    resultTypes.push_back(
+        numOut > 2 ? customOp->getResultTypes()[2] : rewriter.getNoneType());
+
+    auto rms = rewriter.create<ONNXRMSLayerNormalizationOp>(
+        loc, resultTypes, input, scale, bias, axisAttr, epsAttr, stashTypeAttr);
+
+    SmallVector<Value, 3> replace;
+    replace.push_back(rms.getResult(0));
+    if (numOut > 1)
+      replace.push_back(none);
+    if (numOut > 2)
+      replace.push_back(rms.getResult(1));
+
+    SmallVector<Value, 4> toCheck(replace.begin(), replace.end());
+    toCheck.push_back(none);
+
+    if (failed(verifyOpsErasingOnError(toCheck, rewriter))) {
+      return rewriter.notifyMatchFailure(customOp, "Failed verification");
+    }
+
+    rewriter.replaceOp(customOp, replace);
+    return success();
+  }
+};
+
+struct MicrosoftSkipLayerNorm : public CustomOpToOnnxOps {
   MicrosoftSkipLayerNorm(MLIRContext *ctx, PatternBenefit b = 1)
-      : CustomOpMicrosoftToOnnxOps(ctx, "SkipLayerNormalization", b) {}
+      : CustomOpToOnnxOps(
+            ctx, MicrosoftDomainName, "SkipLayerNormalization", b) {}
 
   LogicalResult matchAndRewriteImpl(
       ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3082,21 +3154,21 @@ struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
 
     const auto si64Type = rewriter.getIntegerType(64, /*signed*/ true);
 
-    auto ln = rewriter.create<ONNXLayerNormalizationOp>(loc, resultTypes, sumIS,
-        gamma, beta, /*axis*/
+    auto rms = rewriter.create<ONNXLayerNormalizationOp>(loc, resultTypes,
+        sumIS, gamma, beta, /*axis*/
         rewriter.getIntegerAttr(si64Type, -1), epsAttr,
         /*stashType*/ rewriter.getIntegerAttr(si64Type, 1));
 
     SmallVector<Value, 4> replace;
-    replace.push_back(ln.getResult(0));
+    replace.push_back(rms.getResult(0));
     if (numOut >= 2)
-      replace.push_back(ln.getResult(1)); // mean
+      replace.push_back(rms.getResult(1)); // mean
     if (numOut >= 3)
-      replace.push_back(ln.getResult(2)); // inv_std_var
+      replace.push_back(rms.getResult(2)); // inv_std_var
     if (numOut == 4)
       replace.push_back(sumIS); // input_skip_bias_sum
 
-    SmallVector<Value, 6> toCheck(replace.begin(), replace.end());
+    SmallVector<Value, 7> toCheck(replace.begin(), replace.end());
     toCheck.push_back(none);
     toCheck.push_back(skipAdd);
     toCheck.push_back(sumIS);
@@ -3111,8 +3183,11 @@ struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
 };
 
 template <typename OpToCreate>
-struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
-  using CustomOpMicrosoftToOnnxOps::CustomOpMicrosoftToOnnxOps;
+struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpToOnnxOps {
+  CustomOpMicrosoftToSingleOnnxOp(MLIRContext *context,
+      StringRef operationNameToRewrite, PatternBenefit benefit = 1)
+      : CustomOpToOnnxOps(
+            context, MicrosoftDomainName, operationNameToRewrite, benefit) {}
 
   LogicalResult matchAndRewriteImpl(
       ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
@@ -3518,6 +3593,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert<MicrosoftBiasGelu>(context);
   patterns.insert<MicrosoftFusedConv>(context);
   patterns.insert<MicrosoftSkipLayerNorm>(context);
+  patterns.insert<SimplifiedLayerNorm>(context);
   patterns.insert<DecomposeSlicePadPattern>(context);
   patterns.insert<DecomposeScatterNDPattern>(context);
   patterns.insert<SoftmaxCrossEntropyPattern>(context);

test/mlir/onnx/onnx_decompose_customop.mlir

Lines changed: 68 additions & 0 deletions
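
Note: the file's RUN line sits above this hunk and is therefore not shown in the diff. For onnx-mlir decompose lit tests like this one, the invocation typically has roughly the following shape; the exact flags here are an assumption for illustration, not taken from this commit:

// RUN: onnx-mlir-opt --decompose-onnx %s -split-input-file | FileCheck %s  // hypothetical RUN line, not part of this diff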
@@ -447,3 +447,71 @@ func.func @skip_layernorm_four_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<
 // CHECK:           onnx.Return [[VAR_Y_]], [[VAR_Mean_]], [[VAR_InvStdDev_]], [[VAR_1_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
 }
 
+
+// -----
+// SimplifiedLayerNormalization: 2 inputs (no bias), 1 output
+
+func.func @simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32, axis = -1 : si64, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL:  func.func @simplified_layernorm_basic
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK:           [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK:           [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK:           onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SimplifiedLayerNormalization: 3 inputs (with bias), 1 output
+
+func.func @simplified_layernorm_bias(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>, %bias: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %scale, %bias) {domain_name = "", function_name = "SimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32, axis = -1 : si64, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL:  func.func @simplified_layernorm_bias
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK:           [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none)
+// CHECK:           onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SimplifiedLayerNormalization: 2 inputs, 2 outputs (output, mean) -> mean is unused
+
+func.func @simplified_layernorm_two_outputs_mean_unused(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r0, %r1 = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32, axis = -1 : si64, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0 : tensor<2x4x8xf32>
+// CHECK-LABEL:  func.func @simplified_layernorm_two_outputs_mean_unused
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK:           [[VAR_0_:%.+]]:2 = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+// CHECK:           onnx.Return [[VAR_0_]]#0 : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SimplifiedLayerNormalization: 2 inputs, 3 outputs (output, mean, inv_std_var) -> mean is unused
+
+func.func @simplified_layernorm_three_outputs_mean_unused(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1, %r2 = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32, axis = -1 : si64, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL:  func.func @simplified_layernorm_three_outputs_mean_unused
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK:           [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK:           [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+// CHECK:           onnx.Return [[VAR_Y_]], [[VAR_InvStdDev_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// Negative: SimplifiedLayerNormalization: 2 inputs, 2 outputs (output, mean) -> mean is used
+
+func.func @simplified_layernorm_two_outputs_mean_used(%input: tensor<2x4x8xf32>, %scale: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1 = "onnx.Custom"(%input, %scale) {domain_name = "", function_name = "SimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32, axis = -1 : si64, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL:  func.func @simplified_layernorm_two_outputs_mean_used
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK:           [[VAR_0_:%.+]]:2 = "onnx.Custom"([[PARAM_0_]], [[PARAM_1_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+// CHECK:           onnx.Return [[VAR_0_]]#0, [[VAR_0_]]#1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+}
+
+
