@@ -111,20 +111,28 @@ struct ONNXHybridTransformPass
111111 " phased Conv" ),
112112 ::llvm::cl::init (false )};
113113
114+ Option<bool > recomposeLayernormByTranspose{*this ,
115+ " recompose-layernorm-by-transpose" ,
116+ llvm::cl::desc (" Use transpose operator to make unsuitable axes suitable "
117+ " for matching layernorm" ),
118+ ::llvm::cl::init (false )};
119+
114120 FrozenRewritePatternSet patterns;
115121
116122 ONNXHybridTransformPass (bool enableRecomposition,
117123 bool enableQuarkQuantizedOpsLegalization,
118124 bool enableConvTransposeDecompose,
119125 bool enableConvTransposeDecomposeToPhasedConv,
120- bool enableConvTranspose1dDecomposeToPhasedConv) {
126+ bool enableConvTranspose1dDecomposeToPhasedConv,
127+ bool recomposeLayernormByTranspose) {
121128 this ->recomposition = enableRecomposition;
122129 this ->quarkQuantizedOpsLegalization = enableQuarkQuantizedOpsLegalization;
123130 this ->enableConvTransposeDecompose = enableConvTransposeDecompose;
124131 this ->enableConvTransposeDecomposeToPhasedConv =
125132 enableConvTransposeDecomposeToPhasedConv;
126133 this ->enableConvTranspose1dDecomposeToPhasedConv =
127134 enableConvTranspose1dDecomposeToPhasedConv;
135+ this ->recomposeLayernormByTranspose = recomposeLayernormByTranspose;
128136 }
129137
130138 ONNXHybridTransformPass (const ONNXHybridTransformPass &pass)
@@ -171,7 +179,8 @@ struct ONNXHybridTransformPass
171179 }
172180
173181 if (recomposition) {
174- getRecomposeONNXToONNXPatterns (cumulativePatterns);
182+ getRecomposeONNXToONNXPatterns (
183+ cumulativePatterns, recomposeLayernormByTranspose);
175184 }
176185
177186 patterns = FrozenRewritePatternSet (std::move (cumulativePatterns));
@@ -210,9 +219,11 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass(
210219 bool enableRecomposition, bool enableQuarkQuantizedOpsLegalization,
211220 bool enableConvTransposeDecompose,
212221 bool enableConvTransposeDecomposeToPhasedConv,
213- bool enableConvTranspose1dDecomposeToPhasedConv) {
222+ bool enableConvTranspose1dDecomposeToPhasedConv,
223+ bool enableRecomposeLayernormByTranspose) {
214224 return std::make_unique<ONNXHybridTransformPass>(enableRecomposition,
215225 enableQuarkQuantizedOpsLegalization, enableConvTransposeDecompose,
216226 enableConvTransposeDecomposeToPhasedConv,
217- enableConvTranspose1dDecomposeToPhasedConv);
227+ enableConvTranspose1dDecomposeToPhasedConv,
228+ enableRecomposeLayernormByTranspose);
218229}
0 commit comments