
Commit 71948db

Allow folding of an add into a layernorm if the bias is zero. Add missing shape check
Signed-off-by: Rickert, Jonas <[email protected]>
Parent: dbd0d83

2 files changed: 23 additions, 1 deletion

src/Dialect/ONNX/ONNXOps/Canonicalize.cpp

Lines changed: 9 additions & 1 deletion
@@ -1746,8 +1746,16 @@ struct PropagateBiasIntoLayerNormRewritePattern
     if (!yLayerNormOp->hasOneUse())
       return reportFailure("y/layer norm has too many uses");
     auto lnOp = mlir::cast<OP_TYPE>(yLayerNormOp);
-    if (!onnx_mlir::isNoneValue(lnOp.getB()))
+    if (!isValueNoneOrConstZero(lnOp.getB()))
       return reportFailure("layer norm already has a bias");
+
+    // Norms only support unidirectional broadcasting from bias to y
+    const auto yType = dyn_cast<ShapedType>(y.getType());
+    const auto addType = dyn_cast<ShapedType>(addOp.getType());
+    if (!yType || !addType || !yType.hasStaticShape() ||
+        !addType.hasStaticShape() || yType.getShape() != addType.getShape()) {
+      return rewriter.notifyMatchFailure(addOp, "incompatible shapes");
+    }
     // We are fine.
     Value x = lnOp.getX();
     Value scale = lnOp.getScale();
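
The new check calls isValueNoneOrConstZero, whose definition is not part of this hunk. As a rough sketch of what such a helper could look like (an assumption, not the code added by this commit), it only needs to accept either a missing bias or a constant all-zero splat:

// Sketch only; assumes mlir/IR/Matchers.h is included for matchPattern /
// m_AnyZeroFloat, and onnx_mlir::isNoneValue is available as elsewhere in
// Canonicalize.cpp.
static bool isValueNoneOrConstZero(mlir::Value val) {
  // A "none" bias behaves exactly like a zero bias.
  if (onnx_mlir::isNoneValue(val))
    return true;
  // Otherwise accept a constant floating-point splat of +0.0 or -0.0.
  return mlir::matchPattern(val, mlir::m_AnyZeroFloat());
}

Treating a zero-constant B the same as a missing B is what allows the trailing onnx.Add to be folded into the norm's bias operand.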

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 14 additions & 0 deletions
@@ -2392,6 +2392,20 @@ func.func @layernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<76
 
 // -----
 
+func.func @layernorm_with_zero_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> {
+  %0 = onnx.Constant dense<0.000000e+00> : tensor<768xf32>
+  %NormScaled, %Mean, %InvStdDev = "onnx.LayerNormalization"(%arg0, %arg1, %0) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
+  %Y = "onnx.Add"(%bias, %NormScaled) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
+  return %Y : tensor<1x384x768xf32>
+// CHECK-LABEL: func.func @layernorm_with_zero_bias
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
+// CHECK: [[Y_:%.+]], [[Mean_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.LayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, tensor<768xf32>) -> (tensor<1x384x768xf32>, none, none)
+// CHECK: return [[Y_]] : tensor<1x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
 // Recognize the bias and fold into RMSLayerNorm.
 func.func @rmslayernorm_without_bias(%arg0: tensor<1x384x768xf32>, %arg1: tensor<768xf32>, %bias: tensor<768xf32>) -> tensor<1x384x768xf32> {
   %0 = "onnx.NoValue"() {value} : () -> none
