
Commit 24f4192

Add decomposition of microsoft SkipLayerNorm
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent 2dbf4f3 commit 24f4192
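
In short, the new pattern rewrites the com.microsoft SkipLayerNormalization custom op into plain ONNX ops. A minimal MLIR sketch of the 5-input, 1-output case, with illustrative shapes and SSA names taken from the new lit tests below (not a fixed interface):

  %sum = "onnx.Add"(%input, %skip) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
  %sum_bias = "onnx.Add"(%sum, %bias) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
  %y, %mean, %inv_std_dev = "onnx.LayerNormalization"(%sum_bias, %gamma, %beta) {axis = -1 : si64, epsilon = 1.000000e-05 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)

The second Add is emitted only when the optional pre-norm bias input is present; the optional mean, inv_std_var, and input_skip_bias_sum outputs are forwarded from the LayerNormalization results and the summed input, respectively.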

File tree

2 files changed: 181 additions & 3 deletions


src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 91 additions & 2 deletions
@@ -23,6 +23,7 @@
 #include <cmath>
 #include <numeric>
 
+#include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -2875,10 +2876,20 @@ struct CustomOpMicrosoftToOnnxOps : public OpRewritePattern<ONNXCustomOp> {
     if (llvm::any_of(values, [](Value value) {
           return value && failed(verifyOpValidity(value.getDefiningOp()));
         })) {
-      for (auto value : values)
+      SmallVector<Operation *> opsToErase;
+      for (auto value : values) {
         if (value) {
-          rewriter.eraseOp(value.getDefiningOp());
+          opsToErase.push_back(value.getDefiningOp());
         }
+      }
+      llvm::sort(opsToErase);
+      opsToErase.erase(llvm::unique(opsToErase), opsToErase.end());
+      // We need to ensure that the ops get erased in reverse topological order,
+      // as an op may only be erased once it has no remaining uses.
+      computeTopologicalSorting(opsToErase);
+      for (auto *op : llvm::reverse(opsToErase)) {
+        rewriter.eraseOp(op);
+      }
       return failure();
     }
     return success();
@@ -3022,6 +3033,83 @@ struct MicrosoftFusedConv : public CustomOpMicrosoftToOnnxOps {
   }
 };
 
+struct MicrosoftSkipLayerNorm : public CustomOpMicrosoftToOnnxOps {
+  MicrosoftSkipLayerNorm(MLIRContext *ctx, PatternBenefit b = 1)
+      : CustomOpMicrosoftToOnnxOps(ctx, "SkipLayerNormalization", b) {}
+
+  LogicalResult matchAndRewriteImpl(
+      ONNXCustomOp customOp, PatternRewriter &rewriter) const final {
+    using namespace onnx_mlir;
+    Location loc = customOp.getLoc();
+    const int64_t numIn = customOp.getNumOperands();
+    assert((numIn >= 3 && numIn <= 5) && "expects 3..5 inputs");
+    const int64_t numOut = customOp.getNumResults();
+    assert((numOut >= 1 && numOut <= 4) && "expects 1..4 outputs");
+
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, customOp->getLoc());
+
+    Value none = create.onnx.none();
+
+    Value input = customOp.getOperand(0);
+    Value skip = customOp.getOperand(1);
+    Value gamma = customOp.getOperand(2);
+    Value beta = none; // layer-norm bias
+    Value bias;        // pre-norm bias
+
+    if (numIn >= 4)
+      beta = customOp.getOperand(3);
+    if (numIn == 5)
+      bias = customOp.getOperand(4);
+
+    auto epsAttr = customOp->getAttrOfType<FloatAttr>("epsilon");
+    assert(epsAttr && "Expected Epsilon");
+
+    Value skipAdd = create.onnx.add(input, skip);
+    Value sumIS;
+    if (bias) {
+      sumIS = create.onnx.add(skipAdd, bias);
+    } else {
+      sumIS = skipAdd;
+      skipAdd = nullptr;
+    }
+
+    SmallVector<Type, 3> resultTypes;
+    resultTypes.push_back(customOp->getResultTypes()[0]);
+    resultTypes.push_back(
+        numOut > 1 ? customOp->getResultTypes()[1] : rewriter.getNoneType());
+    resultTypes.push_back(
+        numOut > 2 ? customOp->getResultTypes()[2] : rewriter.getNoneType());
+
+    const auto si64Type = rewriter.getIntegerType(64, /*signed*/ true);
+
+    auto ln = rewriter.create<ONNXLayerNormalizationOp>(loc, resultTypes, sumIS,
+        gamma, beta, /*axis*/
+        rewriter.getIntegerAttr(si64Type, -1), epsAttr,
+        /*stashType*/ rewriter.getIntegerAttr(si64Type, 1));
+
+    SmallVector<Value, 4> replace;
+    replace.push_back(ln.getResult(0));
+    if (numOut >= 2)
+      replace.push_back(ln.getResult(1)); // mean
+    if (numOut >= 3)
+      replace.push_back(ln.getResult(2)); // inv_std_var
+    if (numOut == 4)
+      replace.push_back(sumIS); // input_skip_bias_sum
+
+    SmallVector<Value, 6> toCheck(replace.begin(), replace.end());
+    toCheck.push_back(none);
+    toCheck.push_back(skipAdd);
+    toCheck.push_back(sumIS);
+
+    if (failed(verifyOpsErasingOnError(toCheck, rewriter))) {
+      return rewriter.notifyMatchFailure(customOp, "Failed verification");
+    }
+
+    rewriter.replaceOp(customOp, replace);
+    return success();
+  }
+};
+
 template <typename OpToCreate>
 struct CustomOpMicrosoftToSingleOnnxOp : public CustomOpMicrosoftToOnnxOps {
   using CustomOpMicrosoftToOnnxOps::CustomOpMicrosoftToOnnxOps;
@@ -3429,6 +3517,7 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
   patterns.insert<CustomOpMicrosoftToSingleOnnxOp<ONNXGeluOp>>(context, "Gelu");
   patterns.insert<MicrosoftBiasGelu>(context);
   patterns.insert<MicrosoftFusedConv>(context);
+  patterns.insert<MicrosoftSkipLayerNorm>(context);
   patterns.insert<DecomposeSlicePadPattern>(context);
   patterns.insert<DecomposeScatterNDPattern>(context);
   patterns.insert<SoftmaxCrossEntropyPattern>(context);

test/mlir/onnx/onnx_decompose_customop.mlir

Lines changed: 90 additions & 1 deletion
@@ -357,4 +357,93 @@ func.func @fusedconv_too_many_operands(%x: tensor<1x3x8x8xf32>, %w: tensor<4x3x3
 // CHECK: onnx.Return [[VAR_0_]] : tensor<1x4x8x8xf32>
 // CHECK: }
 
-}
+}
+
+// -----
+// SkipLayerNormalization: 3 inputs, 1 output
+
+func.func @skip_layernorm_basic(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_basic
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[VAR_0_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, none) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 4 inputs (beta), 1 output
+
+func.func @skip_layernorm_beta(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_beta
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_0_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs (beta + bias), 1 output
+
+func.func @skip_layernorm_beta_bias(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> tensor<2x4x8xf32> {
+  %r = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+  onnx.Return %r : tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_beta_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> tensor<2x4x8xf32> {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, none, none)
+// CHECK: onnx.Return [[VAR_Y_]] : tensor<2x4x8xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 2 outputs (output, mean)
+
+func.func @skip_layernorm_two_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1 : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_layernorm_two_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, none)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 3 outputs (output, mean, inv_std_var)
+
+func.func @skip_layernorm_three_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+  %r0, %r1, %r2 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+  onnx.Return %r0, %r1, %r2 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+// CHECK-LABEL: func.func @skip_layernorm_three_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]], [[VAR_InvStdDev_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>
+}
+
+
+// -----
+// SkipLayerNormalization: 5 inputs, 4 outputs (output, mean, inv_std_var, sum)
+
+func.func @skip_layernorm_four_outputs(%input: tensor<2x4x8xf32>, %skip: tensor<2x4x8xf32>, %gamma: tensor<8xf32>, %beta: tensor<8xf32>, %bias: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+  %r0, %r1, %r2, %r3 = "onnx.Custom"(%input, %skip, %gamma, %beta, %bias) {domain_name = "com.microsoft", function_name = "SkipLayerNormalization", epsilon = 1.000000e-05 : f32} : (tensor<2x4x8xf32>, tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>)
+  onnx.Return %r0, %r1, %r2, %r3 : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+// CHECK-LABEL: func.func @skip_layernorm_four_outputs
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x4x8xf32>, [[PARAM_1_:%.+]]: tensor<2x4x8xf32>, [[PARAM_2_:%.+]]: tensor<8xf32>, [[PARAM_3_:%.+]]: tensor<8xf32>, [[PARAM_4_:%.+]]: tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>) {
+// CHECK: [[VAR_0_:%.+]] = "onnx.Add"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<2x4x8xf32>, tensor<2x4x8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_4_]]) : (tensor<2x4x8xf32>, tensor<8xf32>) -> tensor<2x4x8xf32>
+// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[VAR_1_]], [[PARAM_2_]], [[PARAM_3_]]) {axis = -1 : si64, epsilon = 9.99999974E-6 : f32, stash_type = 1 : si64} : (tensor<2x4x8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>)
+// CHECK: onnx.Return [[VAR_Y_]], [[VAR_Mean_]], [[VAR_InvStdDev_]], [[VAR_1_]] : tensor<2x4x8xf32>, tensor<2x4x1xf32>, tensor<2x4x1xf32>, tensor<2x4x8xf32>
+}
+
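
These cases run through onnx-mlir's lit test suite; a hedged example RUN line (the exact pass flag and options are assumptions based on onnx-mlir's usual decompose tests and are not part of this diff):

  // RUN: onnx-mlir-opt --decompose-onnx --split-input-file %s | FileCheck %s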
