Commit 2656a25

feat: Match LayerNorm decomposition pattern with transposing original input and axis
1 parent 2760083 commit 2656a25

10 files changed: 478 additions, 39 deletions


src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 4 additions & 2 deletions
@@ -32,13 +32,15 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
           opts.enableConvTransposeDecomposeToPhasedConv,
           opts.enableConvTranspose1dDecomposeToPhasedConv));
   if (!opts.disableRecomposeOption)
-    pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass());
+    pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass(
+        /*target=*/"", opts.enableRecomposeLayernormByTranspose));
+
   if (opts.enableONNXHybridPass) {
     pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
         !opts.disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
         opts.enableConvTransposeDecompose,
         opts.enableConvTransposeDecomposeToPhasedConv,
-        opts.enableConvTranspose1dDecomposeToPhasedConv));
+        opts.enableConvTranspose1dDecomposeToPhasedConv, opts.enableRecomposeLayernormByTranspose));
   // Convolution Optimization for CPU: enable when there are no accelerators.
   if (targetCPU && opts.enableConvOptPass) {
     pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(

src/Compiler/OnnxToMlirPasses.hpp

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ struct OnnxToMlirOptions {
   bool enableRemoveDqQOp = true;
   bool enableRemoveDqQAroundOp = true;
   bool enableRemoveBinary = false;
+  bool enableRecomposeLayernormByTranspose = false;

   bool disableRecomposeOption = false;
   bool enableONNXHybridPass = true;

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 28 additions & 0 deletions
@@ -591,6 +591,34 @@ Value OnnxBuilder::unsqueeze(Type outputType, Value data, Value axes) const {
       toTensor(outputType), toTensor(data), toTensor(axes));
 }

+Value OnnxBuilder::upRank(
+    mlir::Value data, int64_t toRank, bool trailing) const {
+  assert(data && "the value doesn't exist");
+
+  auto tensor = mlir::cast<mlir::ShapedType>(data.getType());
+  auto shape = getShape(tensor);
+  auto rank = getRank(tensor);
+  assert(rank <= toRank && "the rank of the tensor must be smaller");
+
+  if (rank == toRank)
+    return data;
+
+  int64_t rankDiff = toRank - rank;
+  SmallVector<int64_t> newShape;
+  if (trailing) {
+    newShape.append(shape.begin(), shape.end());
+    newShape.append(SmallVector<int64_t>(rankDiff, 1));
+  } else {
+    newShape.resize(rankDiff, 1);
+    newShape.append(shape.begin(), shape.end());
+  }
+
+  auto newType = tensor.clone(newShape);
+  auto shapeConst = constantInt64(newShape);
+  auto reshaped = reshape(newType, data, shapeConst);
+  return reshaped;
+}
+
 Value OnnxBuilder::where(
     Type outputType, Value condition, Value X, Value Y) const {
   return createTypedOpAndInferShapes<ONNXWhereOp>(
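
For illustration, here is a minimal standalone sketch of the shape computation upRank performs before it emits the Reshape. The helper name upRankShape and the bias size 768 are invented for the example; only the leading/trailing padding rule comes from the diff above.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Pad a shape with size-1 dimensions on the left (trailing == false, the
// default in upRank) or on the right (trailing == true).
std::vector<int64_t> upRankShape(
    std::vector<int64_t> shape, int64_t toRank, bool trailing = false) {
  assert(static_cast<int64_t>(shape.size()) <= toRank && "rank exceeds toRank");
  std::vector<int64_t> ones(toRank - static_cast<int64_t>(shape.size()), 1);
  if (trailing)
    shape.insert(shape.end(), ones.begin(), ones.end());
  else
    shape.insert(shape.begin(), ones.begin(), ones.end());
  return shape;
}

int main() {
  for (int64_t d : upRankShape({768}, 4)) // leading: prints 1 1 1 768
    std::cout << d << ' ';
  std::cout << '\n';
  for (int64_t d : upRankShape({768}, 4, true)) // trailing: prints 768 1 1 1
    std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}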

src/Dialect/ONNX/DialectBuilder.hpp

Lines changed: 5 additions & 0 deletions
@@ -246,6 +246,11 @@ struct OnnxBuilder : DialectBuilder {
   mlir::Value unsqueeze(
       mlir::Type outputType, mlir::Value data, mlir::Value axes) const;

+  // Up-rank the data tensor with a Reshape op; `trailing` selects whether the
+  // size-1 dimensions are added as leading or trailing dimensions.
+  mlir::Value upRank(
+      mlir::Value data, int64_t toRank, bool trailing = false) const;
+
   // ONNXWhereOp
   mlir::Value where(mlir::Type outputType, mlir::Value condition, mlir::Value X,
       mlir::Value Y) const;

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 65 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include <numeric>

 #include "mlir/Dialect/Traits.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -584,6 +585,66 @@ class PropagateReshapeThroughBinaryOpPattern
 };
 };

+// This pattern bubbles an AddOp up through a Transpose to keep the bias Add
+// right after the LN_TYPE op, which helps the other patterns fold the Add
+// into the operands of LayerNorm.
+//
+// From: LayerNorm
+//           |
+//       Transpose
+//           |
+//          Add
+//
+// To:
+//       LayerNorm
+//           |
+//          Add
+//           |
+//       Transpose
+template <typename LN_TYPE>
+class BubbleUpBiasForLayerNormPattern : public OpRewritePattern<ONNXAddOp> {
+public:
+  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(
+      ONNXAddOp addOp, PatternRewriter &r) const override {
+    if (!isConstLikeValue(addOp.getB()))
+      return r.notifyMatchFailure(addOp, "not a constant rhs operand");
+
+    auto transposeOp =
+        llvm::dyn_cast_or_null<ONNXTransposeOp>(addOp.getA().getDefiningOp());
+    if (!transposeOp)
+      return r.notifyMatchFailure(addOp, "the producer is not a transpose");
+
+    if (!transposeOp->hasOneUse())
+      return r.notifyMatchFailure(
+          addOp, "cannot bubble up because transpose has other users");
+
+    auto layernormResult = transposeOp.getData();
+    auto layerNorm =
+        llvm::dyn_cast_or_null<LN_TYPE>(layernormResult.getDefiningOp());
+    if (!layerNorm)
+      return r.notifyMatchFailure(
+          transposeOp, "the producer is not a layernorm");
+
+    if (!isNoneValue(layerNorm.getB()))
+      return r.notifyMatchFailure(layerNorm, "layernorm already has a bias");
+
+    OnnxBuilder create(r, addOp.getLoc());
+
+    auto perm = extractFromIntegerArrayAttr<int64_t>(transposeOp.getPermAttr());
+    auto invertedPerm = invertPermutationVector(perm);
+    auto cstReshaped = create.upRank(addOp.getB(), getRank(addOp.getType()));
+    auto cstTransposed = create.transposeInt64(cstReshaped, invertedPerm);
+    auto newAddOp = create.add(layernormResult, cstTransposed);
+    auto transposedBack = create.transposeInt64(newAddOp, perm);
+
+    r.replaceOp(addOp, transposedBack);
+
+    return success();
+  }
+};
+
 // This rewriting is to optimize the scalar Div/Mul in self-attention layers.
 // In particular, it rewrites the following pattern:
 // ```
@@ -2426,6 +2487,10 @@ void ONNXAddOp::getCanonicalizationPatterns(
       PropagateBiasIntoLayerNormRewritePattern<ONNXRMSLayerNormalizationOp>>(
       context);
   results.insert<PropagateReshapeThroughBinaryOpPattern<ONNXAddOp>>(context);
+  results.insert<BubbleUpBiasForLayerNormPattern<ONNXLayerNormalizationOp>>(
+      context);
+  results.insert<BubbleUpBiasForLayerNormPattern<ONNXRMSLayerNormalizationOp>>(
+      context);
 }

 /// on the ONNXAndOp.
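
To see why the rewrite pre-transposes the constant with the inverted permutation, here is a minimal shape-level sketch. invertPerm and applyPerm are illustrative stand-ins for the MLIR utilities (invertPermutationVector and the ONNX Transpose shape rule) used in the pattern; the concrete shapes are invented for the example.

#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for mlir::invertPermutationVector.
std::vector<int64_t> invertPerm(const std::vector<int64_t> &perm) {
  std::vector<int64_t> inv(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    inv[perm[i]] = static_cast<int64_t>(i);
  return inv;
}

// ONNX Transpose on shapes: out[i] = shape[perm[i]].
std::vector<int64_t> applyPerm(
    const std::vector<int64_t> &shape, const std::vector<int64_t> &perm) {
  std::vector<int64_t> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i)
    out[i] = shape[perm[i]];
  return out;
}

int main() {
  // Say the LayerNorm output is [2, 768, 128] and the Transpose uses
  // perm = {0, 2, 1}, so the original Add sees [2, 128, 768] and its
  // up-ranked bias has shape [1, 1, 768].
  std::vector<int64_t> perm = {0, 2, 1};
  std::vector<int64_t> bias = {1, 1, 768};
  // Transposing the bias with the inverted perm gives [1, 768, 1], which
  // broadcasts directly against the LayerNorm output...
  auto biasBeforeTranspose = applyPerm(bias, invertPerm(perm));
  // ...and transposing the new Add result with the original perm restores
  // the layout the old Add produced, so both graphs compute the same values.
  assert(applyPerm(biasBeforeTranspose, perm) == bias);
  return 0;
}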

src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp

Lines changed: 15 additions & 4 deletions
@@ -111,20 +111,28 @@ struct ONNXHybridTransformPass
           "phased Conv"),
       ::llvm::cl::init(false)};

+  Option<bool> recomposeLayernormByTranspose{*this,
+      "recompose-layernorm-by-transpose",
+      llvm::cl::desc("Use transpose operator to make unsuitable axes suitable "
+                     "for matching layernorm"),
+      ::llvm::cl::init(false)};
+
   FrozenRewritePatternSet patterns;

   ONNXHybridTransformPass(bool enableRecomposition,
       bool enableQuarkQuantizedOpsLegalization,
       bool enableConvTransposeDecompose,
       bool enableConvTransposeDecomposeToPhasedConv,
-      bool enableConvTranspose1dDecomposeToPhasedConv) {
+      bool enableConvTranspose1dDecomposeToPhasedConv,
+      bool recomposeLayernormByTranspose) {
     this->recomposition = enableRecomposition;
     this->quarkQuantizedOpsLegalization = enableQuarkQuantizedOpsLegalization;
     this->enableConvTransposeDecompose = enableConvTransposeDecompose;
     this->enableConvTransposeDecomposeToPhasedConv =
         enableConvTransposeDecomposeToPhasedConv;
     this->enableConvTranspose1dDecomposeToPhasedConv =
         enableConvTranspose1dDecomposeToPhasedConv;
+    this->recomposeLayernormByTranspose = recomposeLayernormByTranspose;
   }

   ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
@@ -171,7 +179,8 @@ struct ONNXHybridTransformPass
     }

     if (recomposition) {
-      getRecomposeONNXToONNXPatterns(cumulativePatterns);
+      getRecomposeONNXToONNXPatterns(
+          cumulativePatterns, recomposeLayernormByTranspose);
     }

     patterns = FrozenRewritePatternSet(std::move(cumulativePatterns));
@@ -210,9 +219,11 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass(
     bool enableRecomposition, bool enableQuarkQuantizedOpsLegalization,
     bool enableConvTransposeDecompose,
     bool enableConvTransposeDecomposeToPhasedConv,
-    bool enableConvTranspose1dDecomposeToPhasedConv) {
+    bool enableConvTranspose1dDecomposeToPhasedConv,
+    bool enableRecomposeLayernormByTranspose) {
   return std::make_unique<ONNXHybridTransformPass>(enableRecomposition,
       enableQuarkQuantizedOpsLegalization, enableConvTransposeDecompose,
       enableConvTransposeDecomposeToPhasedConv,
-      enableConvTranspose1dDecomposeToPhasedConv);
+      enableConvTranspose1dDecomposeToPhasedConv,
+      enableRecomposeLayernormByTranspose);
 }
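
For reference, a hedged sketch of how a pipeline could instantiate the pass with the new flag enabled. The argument order follows the updated createONNXHybridTransformPass signature above; the surrounding function, the chosen boolean values, and the abbreviated includes are illustrative only, not part of the commit.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
// ...plus the onnx-mlir header that declares createONNXHybridTransformPass.

// Illustrative wiring; the in-tree caller passes OnnxToMlirOptions fields
// instead of literals (see src/Compiler/OnnxToMlirPasses.cpp above).
void buildExamplePipeline(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(
      onnx_mlir::createONNXHybridTransformPass(
          /*enableRecomposition=*/true,
          /*enableQuarkQuantizedOpsLegalization=*/false,
          /*enableConvTransposeDecompose=*/false,
          /*enableConvTransposeDecomposeToPhasedConv=*/false,
          /*enableConvTranspose1dDecomposeToPhasedConv=*/false,
          /*enableRecomposeLayernormByTranspose=*/true));
}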
