Skip to content

Commit 9b043b7

Browse files
authored
Merge pull request #465 from Xilinx/move_instancenorm_pattern
Refactor: Move InstanceNorm→LayerNorm from Canonicalize to Decompose
2 parents e28cef4 + f7033c9 commit 9b043b7

File tree

12 files changed

+143
-102
lines changed

12 files changed

+143
-102
lines changed

src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
3030
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass(
3131
/*target=*/"", opts.enableConvTransposeDecompose,
3232
opts.enableConvTransposeDecomposeToPhasedConv,
33-
opts.enableConvTranspose1dDecomposeToPhasedConv));
33+
opts.enableConvTranspose1dDecomposeToPhasedConv,
34+
opts.enableInstanceNormDecompose));
3435
if (!opts.disableRecomposeOption)
3536
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass(
3637
/*target=*/"", opts.enableRecomposeLayernormByTranspose));
@@ -41,7 +42,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
4142
opts.enableConvTransposeDecompose,
4243
opts.enableConvTransposeDecomposeToPhasedConv,
4344
opts.enableConvTranspose1dDecomposeToPhasedConv,
44-
opts.enableRecomposeLayernormByTranspose));
45+
opts.enableRecomposeLayernormByTranspose,
46+
opts.enableInstanceNormDecompose));
4547
// Convolution Optimization for CPU: enable when there are no accelerators.
4648
if (targetCPU && opts.enableConvOptPass) {
4749
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
@@ -52,7 +54,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
5254
opts.enableConvTransposeDecompose,
5355
opts.enableConvTransposeDecomposeToPhasedConv,
5456
opts.enableConvTranspose1dDecomposeToPhasedConv,
55-
opts.enableRecomposeLayernormByTranspose));
57+
opts.enableRecomposeLayernormByTranspose,
58+
opts.enableInstanceNormDecompose));
5659
}
5760
} else {
5861
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
@@ -109,7 +112,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
109112
opts.enableConvTransposeDecompose,
110113
opts.enableConvTransposeDecomposeToPhasedConv,
111114
opts.enableConvTranspose1dDecomposeToPhasedConv,
112-
opts.enableRecomposeLayernormByTranspose));
115+
opts.enableRecomposeLayernormByTranspose,
116+
opts.enableInstanceNormDecompose));
113117
} else {
114118
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
115119
pm.addPass(mlir::createCanonicalizerPass());

src/Compiler/OnnxToMlirPasses.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ struct OnnxToMlirOptions {
1616
bool enableConvTransposeDecompose = false;
1717
bool enableConvTransposeDecomposeToPhasedConv = false;
1818
bool enableConvTranspose1dDecomposeToPhasedConv = false;
19+
bool enableInstanceNormDecompose = true;
1920
bool enableRemoveDqQOp = true;
2021
bool enableRemoveDqQAroundOp = true;
2122
bool enableRemoveBinary = false;

src/Dialect/ONNX/ONNXOps.td.inc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3643,7 +3643,6 @@ def ONNXIfOp:ONNX_Op<"If",
36433643

36443644
def ONNXInstanceNormalizationOp:ONNX_Op<"InstanceNormalization",
36453645
[Pure, OpVersionTrait<22>, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
3646-
let hasCanonicalizer = 1;
36473646
let summary = "ONNX InstanceNormalization operation";
36483647
let description = [{
36493648
Carries out instance normalization as described in the paper

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2087,68 +2087,6 @@ class RemoveWhereEqualPattern : public OpRewritePattern<ONNXWhereOp> {
20872087
}
20882088
};
20892089

2090-
// =============================================================================
2091-
// Rewrite pattern for Instance Normalization
2092-
// =============================================================================
2093-
2094-
struct RemoveInstanceNormPattern
2095-
: public OpRewritePattern<ONNXInstanceNormalizationOp> {
2096-
using OpRewritePattern<ONNXInstanceNormalizationOp>::OpRewritePattern;
2097-
2098-
static bool isDecomposable(ONNXInstanceNormalizationOp instanceNormOp) {
2099-
return onnx_mlir::hasStaticShape(instanceNormOp.getInput().getType()) &&
2100-
onnx_mlir::hasStaticShape(instanceNormOp.getOutput().getType());
2101-
}
2102-
2103-
LogicalResult matchAndRewrite(ONNXInstanceNormalizationOp instanceNormOp,
2104-
PatternRewriter &rewriter) const final {
2105-
// Match.
2106-
if (!isDecomposable(instanceNormOp)) {
2107-
return failure();
2108-
}
2109-
2110-
// Get info.
2111-
Value input = instanceNormOp.getInput();
2112-
Value scale = instanceNormOp.getScale();
2113-
Value bias = instanceNormOp.getB();
2114-
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
2115-
Type elementType = inputType.getElementType();
2116-
auto inputShape = inputType.getShape();
2117-
int64_t C = inputShape[1];
2118-
int64_t inputRank = inputType.getRank();
2119-
int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions.
2120-
assert(inputRank > nonSpacialRank &&
2121-
"expected instance norm with input ranks > 2");
2122-
2123-
// Rewrite.
2124-
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
2125-
rewriter, instanceNormOp.getLoc());
2126-
int64_t axis = nonSpacialRank;
2127-
int64_t numInNorm = inputRank - axis;
2128-
// Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm
2129-
// 1s.
2130-
llvm::SmallVector<int64_t, 4> axesList, biasScaleShape;
2131-
biasScaleShape.emplace_back(C);
2132-
for (int64_t i = 1; i <= numInNorm; ++i) {
2133-
biasScaleShape.emplace_back(1);
2134-
axesList.emplace_back(i);
2135-
}
2136-
Value axes = create.onnx.constantInt64(axesList);
2137-
Type biasScaleType = RankedTensorType::get(biasScaleShape, elementType);
2138-
Value newScale = create.onnx.unsqueeze(biasScaleType, scale, axes);
2139-
Value newBias = create.onnx.unsqueeze(biasScaleType, bias, axes);
2140-
// Create output using layer norm.
2141-
Value Y = create.onnx.layerNorm(inputType, input, newScale, newBias, axis,
2142-
instanceNormOp.getEpsilonAttr());
2143-
// Set the type of the output to be the same as the output of the original
2144-
// operation we are trying to replace.
2145-
Y.setType(instanceNormOp.getResult().getType());
2146-
// Replace operation.
2147-
rewriter.replaceOp(instanceNormOp, Y);
2148-
return success();
2149-
}
2150-
};
2151-
21522090
// =============================================================================
21532091
// Rewrite pattern for Group Normalization
21542092
// =============================================================================
@@ -2614,12 +2552,6 @@ void ONNXIdentityOp::getCanonicalizationPatterns(
26142552
results.insert<IdentityEliminationPattern>(context);
26152553
}
26162554

2617-
/// on the ONNXInstanceNormalizationOp.
2618-
void ONNXInstanceNormalizationOp::getCanonicalizationPatterns(
2619-
RewritePatternSet &results, MLIRContext *context) {
2620-
results.insert<RemoveInstanceNormPattern>(context);
2621-
}
2622-
26232555
/// on the ONNXLayoutTransformOp.
26242556
void ONNXLayoutTransformOp::getCanonicalizationPatterns(
26252557
RewritePatternSet &result, MLIRContext *context) {

src/Dialect/ONNX/Transforms/Decompose.cpp

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3512,6 +3512,67 @@ class ReplaceCastLikeByCastPattern : public OpRewritePattern<ONNXCastLikeOp> {
35123512
}
35133513
};
35143514

3515+
// =============================================================================
3516+
// Decompose InstanceNormalization to LayerNormalization
3517+
// =============================================================================
3518+
struct DecomposeInstanceNormPattern
3519+
: public OpRewritePattern<ONNXInstanceNormalizationOp> {
3520+
using OpRewritePattern<ONNXInstanceNormalizationOp>::OpRewritePattern;
3521+
3522+
static bool isDecomposable(ONNXInstanceNormalizationOp instanceNormOp) {
3523+
return onnx_mlir::hasStaticShape(instanceNormOp.getInput().getType()) &&
3524+
onnx_mlir::hasStaticShape(instanceNormOp.getOutput().getType());
3525+
}
3526+
3527+
LogicalResult matchAndRewrite(ONNXInstanceNormalizationOp instanceNormOp,
3528+
PatternRewriter &rewriter) const final {
3529+
// Match.
3530+
if (!isDecomposable(instanceNormOp)) {
3531+
return failure();
3532+
}
3533+
3534+
// Get info.
3535+
Value input = instanceNormOp.getInput();
3536+
Value scale = instanceNormOp.getScale();
3537+
Value bias = instanceNormOp.getB();
3538+
ShapedType inputType = mlir::cast<ShapedType>(input.getType());
3539+
Type elementType = inputType.getElementType();
3540+
auto inputShape = inputType.getShape();
3541+
int64_t C = inputShape[1];
3542+
int64_t inputRank = inputType.getRank();
3543+
int64_t nonSpacialRank = 2; // Batch N and Channel C: 2 dimensions.
3544+
assert(inputRank > nonSpacialRank &&
3545+
"expected instance norm with input ranks > 2");
3546+
3547+
// Rewrite.
3548+
onnx_mlir::MultiDialectBuilder<onnx_mlir::OnnxBuilder> create(
3549+
rewriter, instanceNormOp.getLoc());
3550+
int64_t axis = nonSpacialRank;
3551+
int64_t numInNorm = inputRank - axis;
3552+
// Unsqueeze scale/bias from [C] to [C x 1 x 1 x ... x 1] with numInNorm
3553+
// 1s.
3554+
llvm::SmallVector<int64_t, 4> axesList, biasScaleShape;
3555+
biasScaleShape.emplace_back(C);
3556+
for (int64_t i = 1; i <= numInNorm; ++i) {
3557+
biasScaleShape.emplace_back(1);
3558+
axesList.emplace_back(i);
3559+
}
3560+
Value axes = create.onnx.constantInt64(axesList);
3561+
Type biasScaleType = RankedTensorType::get(biasScaleShape, elementType);
3562+
Value newScale = create.onnx.unsqueeze(biasScaleType, scale, axes);
3563+
Value newBias = create.onnx.unsqueeze(biasScaleType, bias, axes);
3564+
// Create output using layer norm.
3565+
Value Y = create.onnx.layerNorm(inputType, input, newScale, newBias, axis,
3566+
instanceNormOp.getEpsilonAttr());
3567+
// Set the type of the output to be the same as the output of the original
3568+
// operation we are trying to replace.
3569+
Y.setType(instanceNormOp.getResult().getType());
3570+
// Replace operation.
3571+
rewriter.replaceOp(instanceNormOp, Y);
3572+
return success();
3573+
}
3574+
};
3575+
35153576
// =============================================================================
35163577
// Decompose Hardswish to simpler ONNX ops
35173578
// =============================================================================
@@ -3577,13 +3638,15 @@ struct DecomposeONNXToONNXPass
35773638
DecomposeONNXToONNXPass(const std::string &target,
35783639
bool enableConvTransposeDecompose = false,
35793640
bool enableConvTransposeDecomposeToPhasedConv = false,
3580-
bool enableConvTranspose1dDecomposeToPhasedConv = false) {
3641+
bool enableConvTranspose1dDecomposeToPhasedConv = false,
3642+
bool enableInstanceNormDecompose = true) {
35813643
this->target = target;
35823644
this->enableConvTransposeDecompose = enableConvTransposeDecompose;
35833645
this->enableConvTransposeDecomposeToPhasedConv =
35843646
enableConvTransposeDecomposeToPhasedConv;
35853647
this->enableConvTranspose1dDecomposeToPhasedConv =
35863648
enableConvTranspose1dDecomposeToPhasedConv;
3649+
this->enableInstanceNormDecompose = enableInstanceNormDecompose;
35873650
}
35883651

35893652
DecomposeONNXToONNXPass(const DecomposeONNXToONNXPass &pass)
@@ -3594,6 +3657,8 @@ struct DecomposeONNXToONNXPass
35943657
pass.enableConvTransposeDecompose.getValue();
35953658
this->enableConvTransposeDecomposeToPhasedConv =
35963659
pass.enableConvTransposeDecomposeToPhasedConv.getValue();
3660+
this->enableInstanceNormDecompose =
3661+
pass.enableInstanceNormDecompose.getValue();
35973662
}
35983663

35993664
StringRef getArgument() const override { return "decompose-onnx"; }
@@ -3623,6 +3688,12 @@ struct DecomposeONNXToONNXPass
36233688
"phased Conv"),
36243689
::llvm::cl::init(false)};
36253690

3691+
Option<bool> enableInstanceNormDecompose{*this,
3692+
"enable-instancenorm-decompose",
3693+
llvm::cl::desc("Enable decomposition of InstanceNormalization to "
3694+
"LayerNormalization"),
3695+
::llvm::cl::init(true)};
3696+
36263697
void runOnOperation() final;
36273698

36283699
typedef PassWrapper<DecomposeONNXToONNXPass, OperationPass<func::FuncOp>>
@@ -3635,7 +3706,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
36353706
RewritePatternSet patterns(context);
36363707
onnx_mlir::getDecomposeONNXToONNXPatterns(patterns,
36373708
enableConvTransposeDecompose, enableConvTransposeDecomposeToPhasedConv,
3638-
enableConvTranspose1dDecomposeToPhasedConv);
3709+
enableConvTranspose1dDecomposeToPhasedConv, enableInstanceNormDecompose);
36393710
patterns.insert<ReplaceCastLikeByCastPattern>(context);
36403711

36413712
#ifdef ONNX_MLIR_ENABLE_STABLEHLO
@@ -3653,7 +3724,8 @@ void DecomposeONNXToONNXPass::runOnOperation() {
36533724
void onnx_mlir::getDecomposeONNXToONNXPatterns(
36543725
mlir::RewritePatternSet &patterns, bool enableConvTransposeDecompose,
36553726
bool enableConvTransposeDecomposeToPhasedConv,
3656-
bool enableConvTranspose1dDecomposeToPhasedConv) {
3727+
bool enableConvTranspose1dDecomposeToPhasedConv,
3728+
bool enableInstanceNormDecompose) {
36573729
MLIRContext *context = patterns.getContext();
36583730
populateWithGenerated(patterns);
36593731
if (enableConvTransposeDecompose)
@@ -3662,6 +3734,8 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
36623734
convtranspose_phased::populateWithGenerated(patterns);
36633735
if (enableConvTranspose1dDecomposeToPhasedConv)
36643736
convtranspose_1d_phased::populateWithGenerated(patterns);
3737+
if (enableInstanceNormDecompose)
3738+
patterns.insert<DecomposeInstanceNormPattern>(context);
36653739
patterns.insert<onnx_mlir::DecomposeEinsumPattern>(context);
36663740
patterns.insert<ConcatFusePattern>(context);
36673741
patterns.insert<DecomposeHardSwishPattern>(context);
@@ -3699,8 +3773,9 @@ void onnx_mlir::getDecomposeONNXToONNXPatterns(
36993773
std::unique_ptr<mlir::Pass> onnx_mlir::createDecomposeONNXToONNXPass(
37003774
const std::string &target, bool enableConvTransposeDecompose,
37013775
bool enableConvTransposeDecomposeToPhasedConv,
3702-
bool enableConvTranspose1dDecomposeToPhasedConv) {
3776+
bool enableConvTranspose1dDecomposeToPhasedConv,
3777+
bool enableInstanceNormDecompose) {
37033778
return std::make_unique<DecomposeONNXToONNXPass>(target,
37043779
enableConvTransposeDecompose, enableConvTransposeDecomposeToPhasedConv,
3705-
enableConvTranspose1dDecomposeToPhasedConv);
3780+
enableConvTranspose1dDecomposeToPhasedConv, enableInstanceNormDecompose);
37063781
}

src/Dialect/ONNX/Transforms/Decompose.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ namespace onnx_mlir {
2929
void getDecomposeONNXToONNXPatterns(mlir::RewritePatternSet &patterns,
3030
bool enableConvTransposeDecompose,
3131
bool enableConvTransposeDecomposeToPhasedConv,
32-
bool enableConvTranspose1dDecomposeToPhasedConv);
32+
bool enableConvTranspose1dDecomposeToPhasedConv,
33+
bool enableInstanceNormDecompose);
3334

3435
} // namespace onnx_mlir
3536
#endif

src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ struct ONNXHybridTransformPass
111111
"phased Conv"),
112112
::llvm::cl::init(false)};
113113

114+
Option<bool> enableInstanceNormDecompose{*this,
115+
"enable-instancenorm-decompose",
116+
llvm::cl::desc("Enable decomposition of InstanceNormalization to "
117+
"LayerNormalization"),
118+
::llvm::cl::init(true)};
119+
114120
Option<bool> recomposeLayernormByTranspose{*this,
115121
"recompose-layernorm-by-transpose",
116122
llvm::cl::desc("Use transpose operator to make unsuitable axes suitable "
@@ -124,7 +130,7 @@ struct ONNXHybridTransformPass
124130
bool enableConvTransposeDecompose,
125131
bool enableConvTransposeDecomposeToPhasedConv,
126132
bool enableConvTranspose1dDecomposeToPhasedConv,
127-
bool recomposeLayernormByTranspose) {
133+
bool recomposeLayernormByTranspose, bool enableInstanceNormDecompose) {
128134
this->recomposition = enableRecomposition;
129135
this->quarkQuantizedOpsLegalization = enableQuarkQuantizedOpsLegalization;
130136
this->enableConvTransposeDecompose = enableConvTransposeDecompose;
@@ -133,6 +139,7 @@ struct ONNXHybridTransformPass
133139
this->enableConvTranspose1dDecomposeToPhasedConv =
134140
enableConvTranspose1dDecomposeToPhasedConv;
135141
this->recomposeLayernormByTranspose = recomposeLayernormByTranspose;
142+
this->enableInstanceNormDecompose = enableInstanceNormDecompose;
136143
}
137144

138145
ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
@@ -175,7 +182,8 @@ struct ONNXHybridTransformPass
175182
getDecomposeONNXToONNXPatterns(cumulativePatterns,
176183
enableConvTransposeDecompose,
177184
enableConvTransposeDecomposeToPhasedConv,
178-
enableConvTranspose1dDecomposeToPhasedConv);
185+
enableConvTranspose1dDecomposeToPhasedConv,
186+
enableInstanceNormDecompose);
179187
}
180188

181189
if (recomposition) {
@@ -220,10 +228,11 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass(
220228
bool enableConvTransposeDecompose,
221229
bool enableConvTransposeDecomposeToPhasedConv,
222230
bool enableConvTranspose1dDecomposeToPhasedConv,
223-
bool enableRecomposeLayernormByTranspose) {
231+
bool enableRecomposeLayernormByTranspose,
232+
bool enableInstanceNormDecompose) {
224233
return std::make_unique<ONNXHybridTransformPass>(enableRecomposition,
225234
enableQuarkQuantizedOpsLegalization, enableConvTransposeDecompose,
226235
enableConvTransposeDecomposeToPhasedConv,
227236
enableConvTranspose1dDecomposeToPhasedConv,
228-
enableRecomposeLayernormByTranspose);
237+
enableRecomposeLayernormByTranspose, enableInstanceNormDecompose);
229238
}

src/Pass/Passes.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ std::unique_ptr<mlir::Pass> createONNXOpTransformPass(int threshold,
4040
std::unique_ptr<mlir::Pass> createDecomposeONNXToONNXPass(
4141
const std::string &target = "", bool enableConvTransposeDecompose = false,
4242
bool enableConvTransposeDecomposeToPhasedConv = false,
43-
bool enableConvTranspose1dDecomposeToPhasedConv = false);
43+
bool enableConvTranspose1dDecomposeToPhasedConv = false,
44+
bool enableInstanceNormDecompose = true);
4445
std::unique_ptr<mlir::Pass> createRecomposeONNXToONNXPass(
4546
const std::string &target = "",
4647
const bool &recomposeLayernormByTranspose = false);
@@ -86,7 +87,8 @@ std::unique_ptr<mlir::Pass> createONNXHybridTransformPass(
8687
bool enableConvTransposeDecompose = false,
8788
bool enableConvTransposeDecomposeToPhasedConv = false,
8889
bool enableConvTranspose1dDecomposeToPhasedConv = false,
89-
bool enableRecomposeLayernormByTranspose = false);
90+
bool enableRecomposeLayernormByTranspose = false,
91+
bool enableInstanceNormDecompose = true);
9092

9193
/// Pass for analyzing unknown dimension in ONNX operations.
9294
std::unique_ptr<mlir::Pass> createONNXDimAnalysisPass();

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2133,22 +2133,6 @@ func.func @test_reorder_relu_maxpool(%arg0: tensor<1x64x32x32xf32>) -> tensor<1x
21332133

21342134
// -----
21352135

2136-
func.func @test_instancenorm(%arg0: tensor<2x3x4x5x6xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<2x3x4x5x6xf32> {
2137-
%0 = "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32} : (tensor<2x3x4x5x6xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<2x3x4x5x6xf32>
2138-
onnx.Return %0 : tensor<2x3x4x5x6xf32>
2139-
// mlir2FileCheck.py
2140-
// CHECK-LABEL: func.func @test_instancenorm
2141-
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<2x3x4x5x6xf32>, [[PARAM_1_:%.+]]: tensor<3xf32>, [[PARAM_2_:%.+]]: tensor<3xf32>) -> tensor<2x3x4x5x6xf32> {
2142-
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 2, 3]> : tensor<3xi64>
2143-
// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.Unsqueeze"([[PARAM_1_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32>
2144-
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Unsqueeze"([[PARAM_2_]], [[VAR_0_]]) : (tensor<3xf32>, tensor<3xi64>) -> tensor<3x1x1x1xf32>
2145-
// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_2_]]) {axis = 2 : si64, epsilon = 0.00999999977 : f32, stash_type = 1 : si64} : (tensor<2x3x4x5x6xf32>, tensor<3x1x1x1xf32>, tensor<3x1x1x1xf32>) -> (tensor<2x3x4x5x6xf32>, none, none)
2146-
// CHECK: onnx.Return [[Y_]] : tensor<2x3x4x5x6xf32>
2147-
// CHECK: }
2148-
}
2149-
2150-
// -----
2151-
21522136
func.func @test_groupnorm_v18(%arg0: tensor<3x4x2x2xf32>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<3x4x2x2xf32> {
21532137
%0 = "onnx.GroupNormalizationV18"(%arg0, %arg1, %arg2) {epsilon = 0.00999999977 : f32, num_groups = 2 : si64} : (tensor<3x4x2x2xf32>, tensor<2xf32>, tensor<2xf32>) -> tensor<3x4x2x2xf32>
21542138
onnx.Return %0 : tensor<3x4x2x2xf32>

0 commit comments

Comments (0)