Skip to content

Commit 4e5bc98

Browse files
authored
Merge pull request #462 from ehsan-toosi/enadjara.layernorm_with_transpose_main
Feat: Match LayerNorm decomposition pattern with transposing original input and axis
2 parents 2760083 + fb69354 commit 4e5bc98

File tree

11 files changed

+612
-41
lines changed

11 files changed

+612
-41
lines changed

src/Compiler/OnnxToMlirPasses.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
3232
opts.enableConvTransposeDecomposeToPhasedConv,
3333
opts.enableConvTranspose1dDecomposeToPhasedConv));
3434
if (!opts.disableRecomposeOption)
35-
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass());
35+
pm.addNestedPass<func::FuncOp>(onnx_mlir::createRecomposeONNXToONNXPass(
36+
/*target=*/"", opts.enableRecomposeLayernormByTranspose));
37+
3638
if (opts.enableONNXHybridPass) {
3739
pm.addNestedPass<func::FuncOp>(onnx_mlir::createONNXHybridTransformPass(
3840
!opts.disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
3941
opts.enableConvTransposeDecompose,
4042
opts.enableConvTransposeDecomposeToPhasedConv,
41-
opts.enableConvTranspose1dDecomposeToPhasedConv));
43+
opts.enableConvTranspose1dDecomposeToPhasedConv,
44+
opts.enableRecomposeLayernormByTranspose));
4245
// Convolution Optimization for CPU: enable when there are no accelerators.
4346
if (targetCPU && opts.enableConvOptPass) {
4447
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConvOptONNXToONNXPass(
@@ -48,7 +51,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
4851
/*enableQuarkQuantizedOpsLegalization=*/false,
4952
opts.enableConvTransposeDecompose,
5053
opts.enableConvTransposeDecomposeToPhasedConv,
51-
opts.enableConvTranspose1dDecomposeToPhasedConv));
54+
opts.enableConvTranspose1dDecomposeToPhasedConv,
55+
opts.enableRecomposeLayernormByTranspose));
5256
}
5357
} else {
5458
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
@@ -104,7 +108,8 @@ void addONNXToMLIRPasses(mlir::PassManager &pm, bool targetCPU,
104108
!opts.disableRecomposeOption, opts.enableQuarkQuantizedLegalization,
105109
opts.enableConvTransposeDecompose,
106110
opts.enableConvTransposeDecomposeToPhasedConv,
107-
opts.enableConvTranspose1dDecomposeToPhasedConv));
111+
opts.enableConvTranspose1dDecomposeToPhasedConv,
112+
opts.enableRecomposeLayernormByTranspose));
108113
} else {
109114
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
110115
pm.addPass(mlir::createCanonicalizerPass());

src/Compiler/OnnxToMlirPasses.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ struct OnnxToMlirOptions {
1919
bool enableRemoveDqQOp = true;
2020
bool enableRemoveDqQAroundOp = true;
2121
bool enableRemoveBinary = false;
22+
bool enableRecomposeLayernormByTranspose = false;
2223

2324
bool disableRecomposeOption = false;
2425
bool enableONNXHybridPass = true;

src/Dialect/ONNX/DialectBuilder.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,34 @@ Value OnnxBuilder::unsqueeze(Type outputType, Value data, Value axes) const {
591591
toTensor(outputType), toTensor(data), toTensor(axes));
592592
}
593593

594+
Value OnnxBuilder::upRank(
595+
mlir::Value data, int64_t toRank, bool trailing) const {
596+
assert(data && "the value doesn't exist");
597+
598+
auto tensor = mlir::cast<mlir::ShapedType>(data.getType());
599+
auto shape = getShape(tensor);
600+
auto rank = getRank(tensor);
601+
assert(rank <= toRank && "the rank of the tensor must be smaller");
602+
603+
if (rank == toRank)
604+
return data;
605+
606+
int64_t rankDiff = toRank - rank;
607+
SmallVector<int64_t> newShape;
608+
if (trailing) {
609+
newShape.append(shape.begin(), shape.end());
610+
newShape.append(SmallVector<int64_t>(rankDiff, 1));
611+
} else {
612+
newShape.resize(rankDiff, 1);
613+
newShape.append(shape.begin(), shape.end());
614+
}
615+
616+
auto newType = tensor.clone(newShape);
617+
auto shapeConst = constantInt64(newShape);
618+
auto reshaped = reshape(newType, data, shapeConst);
619+
return reshaped;
620+
}
621+
594622
Value OnnxBuilder::where(
595623
Type outputType, Value condition, Value X, Value Y) const {
596624
return createTypedOpAndInferShapes<ONNXWhereOp>(

src/Dialect/ONNX/DialectBuilder.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,11 @@ struct OnnxBuilder : DialectBuilder {
246246
mlir::Value unsqueeze(
247247
mlir::Type outputType, mlir::Value data, mlir::Value axes) const;
248248

249+
// Up-rank the data tensor using a reshape operator. The `trailing` flag
250+
// selects whether the size-1 dimensions are added as leading or trailing dims.
251+
mlir::Value upRank(
252+
mlir::Value data, int64_t toRank, bool trailing = false) const;
253+
249254
// ONNXWhereOp
250255
mlir::Value where(mlir::Type outputType, mlir::Value condition, mlir::Value X,
251256
mlir::Value Y) const;

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <numeric>
2121

2222
#include "mlir/Dialect/Traits.h"
23+
#include "mlir/Dialect/Utils/IndexingUtils.h"
2324
#include "mlir/IR/Matchers.h"
2425
#include "mlir/IR/PatternMatch.h"
2526
#include "mlir/IR/TypeUtilities.h"
@@ -584,6 +585,67 @@ class PropagateReshapeThroughBinaryOpPattern
584585
};
585586
};
586587

588+
// This pattern bubbles up AddOp through transpose to keep the bias Add
589+
// operation right after LN_type op. This helps the other patterns fold the
590+
// add into the operands of a Norm operator.
591+
//
592+
// From:
593+
// Norm operator
594+
// |
595+
// Transpose
596+
// |
597+
// Add
598+
//
599+
// To:
600+
// Norm operator
601+
// |
602+
// Add
603+
// |
604+
// Transpose
605+
template <typename LN_TYPE>
606+
class BubbleUpBiasForNormOpPattern : public OpRewritePattern<ONNXAddOp> {
607+
public:
608+
using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
609+
610+
LogicalResult matchAndRewrite(
611+
ONNXAddOp addOp, PatternRewriter &r) const override {
612+
if (!isConstLikeValue(addOp.getB()))
613+
return r.notifyMatchFailure(addOp, "not a constant rhs operand");
614+
615+
auto transposeOp =
616+
llvm::dyn_cast_or_null<ONNXTransposeOp>(addOp.getA().getDefiningOp());
617+
if (!transposeOp)
618+
return r.notifyMatchFailure(addOp, "the producer is not a transpose");
619+
620+
if (!transposeOp->hasOneUse())
621+
return r.notifyMatchFailure(
622+
addOp, "cannot bubble up because transpose has other user");
623+
624+
auto layernormResult = transposeOp.getData();
625+
auto layerNorm =
626+
llvm::dyn_cast_or_null<LN_TYPE>(layernormResult.getDefiningOp());
627+
if (!layerNorm)
628+
return r.notifyMatchFailure(
629+
transposeOp, "the producer is not a layernorm");
630+
631+
if (!isNoneValue(layerNorm.getB()))
632+
return r.notifyMatchFailure(layerNorm, "layernorm already has a bias");
633+
634+
OnnxBuilder create(r, addOp.getLoc());
635+
636+
auto perm = extractFromIntegerArrayAttr<int64_t>(transposeOp.getPermAttr());
637+
auto invertedPerm = invertPermutationVector(perm);
638+
auto cstReshaped = create.upRank(addOp.getB(), getRank(addOp.getType()));
639+
auto cstTranposed = create.transposeInt64(cstReshaped, invertedPerm);
640+
auto newAddOp = create.add(layernormResult, cstTranposed);
641+
auto transposedBack = create.transposeInt64(newAddOp, perm);
642+
643+
r.replaceOp(addOp, transposedBack);
644+
645+
return success();
646+
};
647+
};
648+
587649
// This rewriting is to optimize the scalar Div/Mul in self-attention layers.
588650
// In particular, it rewrites the following pattern:
589651
// ```
@@ -2426,6 +2488,10 @@ void ONNXAddOp::getCanonicalizationPatterns(
24262488
PropagateBiasIntoLayerNormRewritePattern<ONNXRMSLayerNormalizationOp>>(
24272489
context);
24282490
results.insert<PropagateReshapeThroughBinaryOpPattern<ONNXAddOp>>(context);
2491+
results.insert<BubbleUpBiasForNormOpPattern<ONNXLayerNormalizationOp>>(
2492+
context);
2493+
results.insert<BubbleUpBiasForNormOpPattern<ONNXRMSLayerNormalizationOp>>(
2494+
context);
24292495
}
24302496

24312497
/// on the ONNXAndOp.

src/Dialect/ONNX/Transforms/ONNXHybridTransformPass.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,28 @@ struct ONNXHybridTransformPass
111111
"phased Conv"),
112112
::llvm::cl::init(false)};
113113

114+
Option<bool> recomposeLayernormByTranspose{*this,
115+
"recompose-layernorm-by-transpose",
116+
llvm::cl::desc("Use transpose operator to make unsuitable axes suitable "
117+
"for matching layernorm"),
118+
::llvm::cl::init(false)};
119+
114120
FrozenRewritePatternSet patterns;
115121

116122
ONNXHybridTransformPass(bool enableRecomposition,
117123
bool enableQuarkQuantizedOpsLegalization,
118124
bool enableConvTransposeDecompose,
119125
bool enableConvTransposeDecomposeToPhasedConv,
120-
bool enableConvTranspose1dDecomposeToPhasedConv) {
126+
bool enableConvTranspose1dDecomposeToPhasedConv,
127+
bool recomposeLayernormByTranspose) {
121128
this->recomposition = enableRecomposition;
122129
this->quarkQuantizedOpsLegalization = enableQuarkQuantizedOpsLegalization;
123130
this->enableConvTransposeDecompose = enableConvTransposeDecompose;
124131
this->enableConvTransposeDecomposeToPhasedConv =
125132
enableConvTransposeDecomposeToPhasedConv;
126133
this->enableConvTranspose1dDecomposeToPhasedConv =
127134
enableConvTranspose1dDecomposeToPhasedConv;
135+
this->recomposeLayernormByTranspose = recomposeLayernormByTranspose;
128136
}
129137

130138
ONNXHybridTransformPass(const ONNXHybridTransformPass &pass)
@@ -171,7 +179,8 @@ struct ONNXHybridTransformPass
171179
}
172180

173181
if (recomposition) {
174-
getRecomposeONNXToONNXPatterns(cumulativePatterns);
182+
getRecomposeONNXToONNXPatterns(
183+
cumulativePatterns, recomposeLayernormByTranspose);
175184
}
176185

177186
patterns = FrozenRewritePatternSet(std::move(cumulativePatterns));
@@ -210,9 +219,11 @@ std::unique_ptr<mlir::Pass> onnx_mlir::createONNXHybridTransformPass(
210219
bool enableRecomposition, bool enableQuarkQuantizedOpsLegalization,
211220
bool enableConvTransposeDecompose,
212221
bool enableConvTransposeDecomposeToPhasedConv,
213-
bool enableConvTranspose1dDecomposeToPhasedConv) {
222+
bool enableConvTranspose1dDecomposeToPhasedConv,
223+
bool enableRecomposeLayernormByTranspose) {
214224
return std::make_unique<ONNXHybridTransformPass>(enableRecomposition,
215225
enableQuarkQuantizedOpsLegalization, enableConvTransposeDecompose,
216226
enableConvTransposeDecomposeToPhasedConv,
217-
enableConvTranspose1dDecomposeToPhasedConv);
227+
enableConvTranspose1dDecomposeToPhasedConv,
228+
enableRecomposeLayernormByTranspose);
218229
}

0 commit comments

Comments
 (0)