Skip to content

Commit b7dc57c

Browse files
committed
Address review comments: move enableInstanceNormDecompose to end and add e2e test
1 parent e6a3485 commit b7dc57c

File tree

4 files changed

+33
-14
lines changed

4 files changed

+33
-14
lines changed

src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
4242
opts.enableConvTransposeDecompose,
4343
opts.enableConvTransposeDecomposeToPhasedConv,
4444
opts.enableConvTranspose1dDecomposeToPhasedConv,
45-
opts.enableInstanceNormDecompose,
46-
opts.enableRecomposeLayernormByTranspose));
45+
opts.enableRecomposeLayernormByTranspose,
46+
opts.enableInstanceNormDecompose));
4747
// Convolution Optimization for CPU: enable when there are no accelerators.
4848
if (targetCPU && opts.enableConvOptPass) {
4949
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
@@ -54,8 +54,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
5454
opts.enableConvTransposeDecompose,
5555
opts.enableConvTransposeDecomposeToPhasedConv,
5656
opts.enableConvTranspose1dDecomposeToPhasedConv,
57-
opts.enableInstanceNormDecompose,
58-
opts.enableRecomposeLayernormByTranspose));
57+
opts.enableRecomposeLayernormByTranspose,
58+
opts.enableInstanceNormDecompose));
5959
}
6060
} else {
6161
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
@@ -112,8 +112,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
112112
opts.enableConvTransposeDecompose,
113113
opts.enableConvTransposeDecomposeToPhasedConv,
114114
opts.enableConvTranspose1dDecomposeToPhasedConv,
115-
opts.enableInstanceNormDecompose,
116-
opts.enableRecomposeLayernormByTranspose));
115+
opts.enableRecomposeLayernormByTranspose,
116+
opts.enableInstanceNormDecompose));
117117
} else {
118118
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
119119
pm.addPass(mlir::createCanonicalizerPass());

src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,16 @@ struct ONNXHybridTransformPass
130130
bool enableConvTransposeDecompose,
131131
bool enableConvTransposeDecomposeToPhasedConv,
132132
bool enableConvTranspose1dDecomposeToPhasedConv,
133-
bool enableInstanceNormDecompose, bool recomposeLayernormByTranspose) {
133+
bool recomposeLayernormByTranspose, bool enableInstanceNormDecompose) {
134134
this->recomposition = enableRecomposition;
135135
this->quarkQuantizedOpsLegalization = enableQuarkQuantizedOpsLegalization;
136136
this->enableConvTransposeDecompose = enableConvTransposeDecompose;
137137
this->enableConvTransposeDecomposeToPhasedConv =
138138
enableConvTransposeDecomposeToPhasedConv;
139139
this->enableConvTranspose1dDecomposeToPhasedConv =
140140
enableConvTranspose1dDecomposeToPhasedConv;
141-
this->enableInstanceNormDecompose = enableInstanceNormDecompose;
142141
this->recomposeLayernormByTranspose = recomposeLayernormByTranspose;
142+
this->enableInstanceNormDecompose = enableInstanceNormDecompose;
143143
}
144144

145145
ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
@@ -228,11 +228,11 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass(
228228
bool enableConvTransposeDecompose,
229229
bool enableConvTransposeDecomposeToPhasedConv,
230230
bool enableConvTranspose1dDecomposeToPhasedConv,
231-
bool enableInstanceNormDecompose,
232-
bool enableRecomposeLayernormByTranspose) {
231+
bool enableRecomposeLayernormByTranspose,
232+
bool enableInstanceNormDecompose) {
233233
return std::make_unique<ONNXHybridTransformPass>(enableRecomposition,
234234
enableQuarkQuantizedOpsLegalization, enableConvTransposeDecompose,
235235
enableConvTransposeDecomposeToPhasedConv,
236-
enableConvTranspose1dDecomposeToPhasedConv, enableInstanceNormDecompose,
237-
enableRecomposeLayernormByTranspose);
236+
enableConvTranspose1dDecomposeToPhasedConv,
237+
enableRecomposeLayernormByTranspose, enableInstanceNormDecompose);
238238
}

src/Pass/Passes.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ std::unique_ptr<mlir::Pass> createONNXHybridTransformPass(
8787
bool enableConvTransposeDecompose = false,
8888
bool enableConvTransposeDecomposeToPhasedConv = false,
8989
bool enableConvTranspose1dDecomposeToPhasedConv = false,
90-
bool enableInstanceNormDecompose = true,
91-
bool enableRecomposeLayernormByTranspose = false);
90+
bool enableRecomposeLayernormByTranspose = false,
91+
bool enableInstanceNormDecompose = true);
9292

9393
/// Pass for analyzing unknown dimension in ONNX operations.
9494
std::unique_ptr<mlir::Pass> createONNXDimAnalysisPass();
(new file — filename not captured in this page; per the commit message this is the InstanceNormalization e2e LIT test)

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: onnx-mlir --EmitONNXBasic --printIR %s | FileCheck %s
2+
<
3+
ir_version: 8,
4+
opset_import: ["" : 6]
5+
>
6+
test_instancenorm_e2e (float[2,3,4,5] input, float[3] scale, float[3] bias) => (float[2,3,4,5] output) {
7+
output = InstanceNormalization <epsilon: float = 0.01> (input, scale, bias)
8+
}
9+
10+
// mlir2FileCheck.py
11+
// CHECK-LABEL: func.func @main_graph
12+
// CHECK-SAME: ([[INPUT:%.+]]: tensor<2x3x4x5xf32> {onnx.name = "input"}, [[SCALE:%.+]]: tensor<3xf32> {onnx.name = "scale"}, [[BIAS:%.+]]: tensor<3xf32> {onnx.name = "bias"}) -> (tensor<2x3x4x5xf32> {onnx.name = "output"}) {
13+
// CHECK-DAG: [[AXES:%.+]] = onnx.Constant dense<[1, 2]> : tensor<2xi64>
14+
// CHECK-DAG: [[UNSQUEEZE_SCALE:%.+]] = "onnx.Unsqueeze"([[SCALE]], [[AXES]]) : (tensor<3xf32>, tensor<2xi64>) -> tensor<3x1x1xf32>
15+
// CHECK-DAG: [[UNSQUEEZE_BIAS:%.+]] = "onnx.Unsqueeze"([[BIAS]], [[AXES]]) : (tensor<3xf32>, tensor<2xi64>) -> tensor<3x1x1xf32>
16+
// CHECK: [[OUTPUT:%.+]], [[MEAN:%.+]], [[INV_STD_DEV:%.+]] = "onnx.LayerNormalization"([[INPUT]], [[UNSQUEEZE_SCALE]], [[UNSQUEEZE_BIAS]]) {axis = 2 : si64, epsilon = 1.000000e-02 : f32, stash_type = 1 : si64} : (tensor<2x3x4x5xf32>, tensor<3x1x1xf32>, tensor<3x1x1xf32>) -> (tensor<2x3x4x5xf32>, none, none)
17+
// CHECK: onnx.Return [[OUTPUT]] : tensor<2x3x4x5xf32>
18+
// CHECK: }
19+

0 commit comments

Comments (0)