Commit 0019e50

Add decomposition of SkipSimplifiedLayerNormalization
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent bccacea commit 0019e50
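
The decomposition rewrites the com.microsoft SkipSimplifiedLayerNormalization custom op as an explicit input + skip (+ bias) addition followed by a plain SimplifiedLayerNormalization, which existing patterns can lower further. As a reference for the intended math, here is a minimal sketch assuming the usual contrib-op semantics (RMS normalization over the last axis); the helper name and per-row formulation are illustrative only and are not part of the commit:

// Reference-only sketch (assumed semantics, not code from this commit):
//   sum = input + skip (+ bias)
//   y   = sum / sqrt(mean(sum^2) + epsilon) * gamma
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> skipSimplifiedLayerNormRow(const std::vector<float> &input,
    const std::vector<float> &skip, const std::vector<float> &gamma,
    const std::vector<float> *bias, float epsilon) {
  const std::size_t hidden = input.size();
  std::vector<float> sum(hidden), y(hidden);
  float meanSquare = 0.0f;
  for (std::size_t i = 0; i < hidden; ++i) {
    // Sum the input, the skip connection, and the optional pre-norm bias.
    sum[i] = input[i] + skip[i] + (bias ? (*bias)[i] : 0.0f);
    meanSquare += sum[i] * sum[i];
  }
  meanSquare /= static_cast<float>(hidden);
  const float invStdDev = 1.0f / std::sqrt(meanSquare + epsilon);
  for (std::size_t i = 0; i < hidden; ++i)
    y[i] = sum[i] * invStdDev * gamma[i]; // normalized sum scaled by gamma
  return y;
}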

File tree

2 files changed, +176 -0 lines changed

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 84 additions & 0 deletions
@@ -3182,6 +3182,89 @@ struct MicrosoftSkipLayerNorm : public CustomOpToOnnxOps {
   }
 };
 
+struct MicrosoftSkipSimplifiedLayerNorm : public CustomOpToOnnxOps {
+  MicrosoftSkipSimplifiedLayerNorm(MLIRContext *ctx, PatternBenefit b = 1)
+      : CustomOpToOnnxOps(
+            ctx, MicrosoftDomainName, "SkipSimplifiedLayerNormalization", b) {}
+
+  LogicalResult matchAndRewriteImpl(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    Location loc = customOp.getLoc();
+    const int64_t numIn = customOp.getNumOperands();
+    assert((numIn >= 3 && numIn <= 4) && "expects 3..4 inputs");
+    const int64_t numOut = customOp.getNumResults();
+    assert((numOut >= 1 && numOut <= 4) && "expects 1..4 outputs");
+
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
+
+    Value none = create.onnx.none();
+
+    Value input = customOp.getOperand(0);
+    Value skip = customOp.getOperand(1);
+    Value gamma = customOp.getOperand(2);
+    Value bias; // pre-norm bias
+
+    if (numIn >= 4)
+      bias = customOp.getOperand(3);
+
+    auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
+    assert(epsAttr && "Expected Epsilon");
+
+    Value skipAdd = create.onnx.add(input, skip);
+    Value sumIS;
+    if (bias) {
+      sumIS = create.onnx.add(skipAdd, bias);
+    } else {
+      sumIS = skipAdd;
+      skipAdd = nullptr;
+    }
+
+    SmallVector<Type, 3> resultTypes;
+    resultTypes.push_back(customOp->getResultTypes()[0]);
+    resultTypes.push_back(
+        numOut > 1 ? customOp->getResultTypes()[1] : rewriter.getNoneType());
+    resultTypes.push_back(
+        numOut > 2 ? customOp->getResultTypes()[2] : rewriter.getNoneType());
+
+    const auto si64Type = rewriter.getIntegerType(64, /*signed*/ true);
+
+    const SmallVector<NamedAttribute, 5> simplifiedLayerNormAttrs{
+        rewriter.getNamedAttr(
+            "domain_name", rewriter.getStringAttr(DefaultONNXDomainName)),
+        rewriter.getNamedAttr("function_name",
+            rewriter.getStringAttr("SimplifiedLayerNormalization")),
+        rewriter.getNamedAttr("axis", rewriter.getIntegerAttr(si64Type, -1)),
+        rewriter.getNamedAttr("epsilon", epsAttr),
+        rewriter.getNamedAttr(
+            "stash_type", rewriter.getIntegerAttr(si64Type, 1))};
+
+    auto skipLayerNorm = rewriter.create<ONNXCustomOp>(
+        loc, resultTypes, ValueRange{sumIS, gamma}, simplifiedLayerNormAttrs);
+
+    SmallVector<Value, 4> replace;
+    replace.push_back(skipLayerNorm.getResult(0));
+    if (numOut >= 2)
+      replace.push_back(skipLayerNorm.getResult(1)); // mean
+    if (numOut >= 3)
+      replace.push_back(skipLayerNorm.getResult(2)); // inv_std_var
+    if (numOut == 4)
+      replace.push_back(sumIS); // input_skip_bias_sum
+
+    SmallVector<Value, 7> toCheck(replace.begin(), replace.end());
+    toCheck.push_back(none);
+    toCheck.push_back(skipAdd);
+    toCheck.push_back(sumIS);
+
+    if (failed(verifyOpsErasingOnError(toCheck, rewriter))) {
+      return rewriter.notifyMatchFailure(customOp, "Failed verification");
+    }
+
+    rewriter.replaceOp(customOp, replace);
+    return success();
+  }
+};
+
 template <typename OpToCreate>
 struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpToOnnxOps {
   CustomOpMicrosoftToSingleOnnxOp(MLIRContext *context,
@@ -3594,6 +3677,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert<MicrosoftFusedConv>(context);
   patterns.insert<MicrosoftSkipLayerNorm>(context);
   patterns.insert<SimplifiedLayerNorm>(context);
+  patterns.insert<MicrosoftSkipSimplifiedLayerNorm>(context);
   patterns.insert<DecomposeSlicePadPattern>(context);
   patterns.insert<DecomposeScatterNDPattern>(context);
   patterns.insert<SoftmaxCrossEntropyPattern>(context);
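
Note that the new pattern only materializes the add(s) and a SimplifiedLayerNormalization custom op. As the updated lit tests below show, cases that do not use the mean output are then lowered further by the existing SimplifiedLayerNorm pattern to onnx.RMSLayerNormalization, while cases that request the mean keep the intermediate SimplifiedLayerNormalization custom op.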

test/mlir/onnx/onnx_decompose_customop.mlir

Lines changed: 92 additions & 0 deletions
@@ -514,4 +514,96 @@ func.func @simplified_layernorm_two_outputs_mean_used(%input: tensor<2x4x8xf32>,
 // CHECK: onnx.Return [[VAR_0_]]#0, [[VAR_0_]]#1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
 }
 
+// -----
+// SkipSimplifiedLayerNormalization: 3 inputs, 1 output
+
+func.func @skip_simplified_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_basic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs (bias), 1 output
+
+func.func @skip_simplified_layernorm_bias(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 2 outputs (output, mean)
+
+func.func @skip_simplified_layernorm_two_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_two_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, none)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK: }
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 3 outputs (output, mean, inv_std_var)
+
+func.func @skip_simplified_layernorm_three_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1, %r2 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1, %r2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_three_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 4 outputs (output, mean, inv_std_var, sum)
+
+func.func @skip_simplified_layernorm_four_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r1, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_four_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]]:3 = "onnx.Custom"([[VAR_1_]], [[PARAM_2_]]) {axis = -1 : si64, domain_name = "", epsilon = 9.99999974E-6 : f32, function_name = "SimplifiedLayerNormalization", stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_2_]]#0, [[VAR_2_]]#1, [[VAR_2_]]#2, [[VAR_1_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK: }
+}
+
+// -----
+// SkipSimplifiedLayerNormalization: 4 inputs, 4 outputs (output, mean, inv_std_var, sum), mean unused
+
+func.func @skip_simplified_layernorm_four_outputs_mean_unused(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %bias) {domain_name = "com.microsoft", function_name = "SkipSimplifiedLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_simplified_layernorm_four_outputs_mean_unused
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.Add"([[VAR_1_]], [[PARAM_3_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[VAR_2_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_InvStdDev_]], [[VAR_2_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+}
+
 
