Commit dbd0d83

Add canonicalization pattern to move mul into (RMS)Layernorm
Signed-off-by: Rickert, Jonas <[email protected]>
1 parent a364c43 commit dbd0d83
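The new pattern folds a trailing Mul into the scale operand of a LayerNormalization or RMSLayerNormalization whose existing scale is the splat constant 1.0 and whose bias is none or zero, provided the norm result has a single use and the shapes match. A minimal before/after sketch of the rewrite (value names and types are illustrative, mirroring the layernorm_with_neutral_scale test added below):

// Before: neutral scale (splat 1.0), no bias, result consumed by a Mul.
%none = "onnx.NoValue"() {value} : () -> none
%one = onnx.Constant dense<1.0> : tensor<768xf32>
%y, %mean, %isd = "onnx.LayerNormalization"(%x, %one, %none) {axis = 2 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
%out = "onnx.Mul"(%scale, %y) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>

// After: %scale becomes the norm's scale operand and the Mul disappears.
%y, %mean, %isd = "onnx.LayerNormalization"(%x, %scale, %none) {axis = 2 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)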

2 files changed: +176 -0 lines changed
src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 88 additions & 0 deletions
@@ -1633,6 +1633,89 @@ struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
// =============================================================================
// Rewrite pattern LayerNormalization
// =============================================================================
namespace {
bool isValueNoneOrConstZero(Value value) {
  if (!value) {
    return false;
  }
  if (isNoneValue(value)) {
    return true;
  }
  auto elementsAttr = getElementAttributeFromONNXValue(value);
  if (!elementsAttr) {
    return false;
  }
  if (!elementsAttr.isSplat()) {
    return false;
  }
  if (!elementsAttr.template getSplatValue<APFloat>().isZero()) {
    return false;
  }
  return true;
}
} // namespace

template <typename OP_TYPE>
struct PropagateScaleIntoLayerNormPattern : public OpRewritePattern<ONNXMulOp> {
  using OpRewritePattern<ONNXMulOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(
      ONNXMulOp mulOp, PatternRewriter &rewriter) const final {
    using namespace onnx_mlir;
    Value y;
    Value mulScale;
    Operation *yLayerNormOp;
    // Match
    // %neutral = "onnx.Constant" {1.0}
    // %y, %mean, %invStdDev = "onnx.LayerNormalization"(%x, %neutral, %noBias)
    // %yScale = "onnx.Mul"(%y, %mulScale)
    if (!onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
            yLayerNormOp, mulOp, y, mulScale, 0) &&
        !onnx_mlir::operandOfOpDefinedBy<OP_TYPE>(
            yLayerNormOp, mulOp, mulScale, y, 1)) {
      return rewriter.notifyMatchFailure(mulOp, "missing y, layer norm op");
    }
    if (!yLayerNormOp->hasOneUse()) {
      return rewriter.notifyMatchFailure(
          mulOp, "y/layer norm has too many uses");
    }
    OP_TYPE normOp = cast<OP_TYPE>(yLayerNormOp);
    // Bias needs to be zero
    if (!isValueNoneOrConstZero(normOp.getB())) {
      return rewriter.notifyMatchFailure(
          mulOp, "layer norm already has a bias");
    }

    auto existingScale = normOp.getScale();
    auto elementsAttr = getElementAttributeFromONNXValue(existingScale);
    if (!elementsAttr) {
      return rewriter.notifyMatchFailure(
          mulOp, "missing elements attribute or scale is not const");
    }
    if (!elementsAttr.isSplat()) {
      return rewriter.notifyMatchFailure(mulOp, "scale is not a splat value");
    }
    if (!elementsAttr.template getSplatValue<APFloat>().isExactlyValue(1.0)) {
      return rewriter.notifyMatchFailure(mulOp, "scale is not 1.0");
    }
    // Norms only support unidirectional broadcasting from scale to y
    const auto yType = dyn_cast<ShapedType>(y.getType());
    const auto mulType = dyn_cast<ShapedType>(mulOp.getType());
    if (!yType || !mulType || !yType.hasStaticShape() ||
        !mulType.hasStaticShape() || yType.getShape() != mulType.getShape()) {
      return rewriter.notifyMatchFailure(mulOp, "incompatible shapes");
    }

    rewriter.moveOpAfter(
        normOp, mulOp); // Make sure we can use the const of the mul
    rewriter.modifyOpInPlace(normOp, [&] {
      normOp.setOperand(/*scale*/ 1, mulScale);
      normOp->setLoc(rewriter.getFusedLoc({normOp.getLoc(), mulOp->getLoc()}));
    });
    rewriter.replaceOp(mulOp, normOp.getY());
    return success();
  }
};

template <typename OP_TYPE>
struct PropagateBiasIntoLayerNormRewritePattern
@@ -2189,6 +2272,11 @@ void ONNXAddOp::getCanonicalizationPatterns(
  results.insert<FuseAddConvNullBiasPattern>(context);
  results.insert<BinaryOpBroadcastAxisPattern<ONNXAddOp>>(context);
  results.insert<PropagateScalarConstantExpandPattern<ONNXAddOp>>(context);
  results.insert<PropagateScaleIntoLayerNormPattern<ONNXLayerNormalizationOp>>(
      context);
  results
      .insert<PropagateScaleIntoLayerNormPattern<ONNXRMSLayerNormalizationOp>>(
          context);
  results.insert<
      PropagateBiasIntoLayerNormRewritePattern<ONNXLayerNormalizationOp>>(
      context);

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 88 additions & 0 deletions
@@ -2405,4 +2405,92 @@ func.func @rmslayernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor
// CHECK: }
}

// -----

// Recognize the scale and fold into LayerNorm.
func.func @layernorm_with_neutral_scale(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
  %0 = "onnx.NoValue"() {value} : () -> none
  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// CHECK-LABEL: func.func @layernorm_with_neutral_scale
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_2_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
// CHECK: return [[VAR_Y_]] : tensor<1x384x768xf32>
// CHECK: }
}

// -----

func.func @layernorm_scale_not_one(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
  %0 = "onnx.NoValue"() {value} : () -> none
  %1 = onnx.Constant dense<1.100000e+00> : tensor<768xf32>
  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// CHECK-LABEL: func.func @layernorm_scale_not_one
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.100000e+00> : tensor<768xf32>
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
// CHECK: return [[VAR_2_]] : tensor<1x384x768xf32>
// CHECK: }
}

// -----

func.func @layernorm_bias_not_zero(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
  %0 = onnx.Constant dense<2.000000e+00> : tensor<768xf32>
  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// CHECK-LABEL: func.func @layernorm_bias_not_zero
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<2.000000e+00> : tensor<768xf32>
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
// CHECK: return [[VAR_2_]] : tensor<1x384x768xf32>
// CHECK: }
}

// -----

func.func @layernorm_broadcast(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<10x384x768xf32>) -> tensor<10x384x768xf32> {
  %0 = "onnx.NoValue"() {value} : () -> none
  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<10x384x768xf32>, tensor<1x384x768xf32>) -> tensor<10x384x768xf32>
  return %Y : tensor<10x384x768xf32>
// CHECK-LABEL: func.func @layernorm_broadcast
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<10x384x768xf32>) -> tensor<10x384x768xf32> {
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
// CHECK: [[VAR_Y_:%.+]], [[VAR_Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[VAR_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none, none)
// CHECK: [[VAR_2_:%.+]] = "onnx.Mul"([[PARAM_2_]], [[VAR_Y_]]) : (tensor<10x384x768xf32>, tensor<1x384x768xf32>) -> tensor<10x384x768xf32>
// CHECK: return [[VAR_2_]] : tensor<10x384x768xf32>
// CHECK: }
}

// -----

// Recognize the scale and fold into the RMSNorm.
func.func @rmslayernorm_with_neutral_scale(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %mulVal: tensor<768xf32>) -> tensor<1x384x768xf32> {
  %0 = "onnx.NoValue"() {value} : () -> none
  %1 = onnx.Constant dense<1.000000e+00> : tensor<768xf32>
  %NormScaled, %InvStdDev = "onnx.RMSLayerNormalization"(%arg0, %1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
  %Y = "onnx.Mul"(%mulVal, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
  return %Y : tensor<1x384x768xf32>
// CHECK-LABEL: func.func @rmslayernorm_with_neutral_scale
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_2_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
// CHECK: return [[VAR_Y_]] : tensor<1x384x768xf32>
// CHECK: }
}
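For context, the hunk does not show the test file's RUN line; these FileCheck tests are driven through lit and the ONNX canonicalizer. A plausible invocation, assuming the standard onnx-mlir-opt driver (the exact passes and flags used by onnx_canonicalization.mlir may differ), would be:

// RUN: onnx-mlir-opt --canonicalize --split-input-file %s | FileCheck %s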
