Skip to content

Commit faf797f

Browse files
authored
Merge pull request #403 from Xilinx/jrickert.rms_multi_use
Recompose a LayerNorm/RMSNorm even if the scale multiplication has multiple uses.
2 parents 0f2391e + 9c0bf14 commit faf797f

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/Dialect/ONNX/Transforms/Recompose.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,8 +373,6 @@ struct RecomposeLayerNormFromMulPattern : public OpRewritePattern<ONNXMulOp> {
373373
return reportFailure("RMS norm mul has too many uses");
374374
if (isdRecipOp && !isdRecipOp->hasOneUse())
375375
return reportFailure("RMS norm recip has too many uses");
376-
if (!nsMulOp->hasOneUse())
377-
return reportFailure("RMS norm scale mul has too many uses");
378376
// Now check values epsilon.
379377
if (!isScalarTensor(epsilon))
380378
return reportFailure("RMS epsilon is expected to be scalar");

test/mlir/onnx/onnx_recompose.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,30 @@ func.func @rms_layer_norm_v3_dyn_shape(%x: tensor<1x?x768xf32>) -> (tensor<1x?x7
511511

512512
// -----
513513

514+
// RMS layer norm where the result of the scale multiplication has more than one use; it must still be recomposed into onnx.RMSLayerNormalization.
515+
516+
func.func @rms_layer_norm_multi_use(%x: tensor<1x384x768xf32>, %scale: tensor<768xf32>, %bias: tensor<768xf32>) -> (tensor<1x384x768xf32>) {
517+
%eps = onnx.Constant dense<1.2E+0> : tensor<f32>
518+
%dd = "onnx.Mul"(%x, %x) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
519+
%var = "onnx.ReduceMeanV13"(%dd) {axes = [-1], keepdims = 1 : si64} : (tensor<1x384x768xf32>) -> tensor<1x384x1xf32>
520+
%varEps = "onnx.Add"(%eps, %var) : (tensor<f32>, tensor<1x384x1xf32>) -> tensor<1x384x1xf32>
521+
%StdDev = "onnx.Sqrt"(%varEps) : (tensor<1x384x1xf32>) -> tensor<1x384x1xf32>
522+
%Norm = "onnx.Div"(%x, %StdDev) : (tensor<1x384x768xf32>, tensor<1x384x1xf32>) -> tensor<1x384x768xf32>
523+
%NormScaled = "onnx.Mul"(%scale, %Norm) : (tensor<768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
524+
%MultiUse = "onnx.Add"(%NormScaled, %NormScaled) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
525+
return %MultiUse : tensor<1x384x768xf32>
526+
// mlir2FileCheck.py
527+
// CHECK-LABEL: func.func @rms_layer_norm_multi_use
528+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<768xf32>, [[PARAM_2_:%.+]]: tensor<768xf32>) -> tensor<1x384x768xf32> {
529+
// CHECK: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
530+
// CHECK: [[VAR_Y_:%.+]], [[VAR_InvStdDev_:%.+]] = "onnx.RMSLayerNormalization"([[PARAM_0_]], [[PARAM_1_]], [[VAR_0_]]) {axis = 2 : si64, epsilon = 1.200000e+00 : f32, stash_type = 1 : si64} : (tensor<1x384x768xf32>, tensor<768xf32>, none) -> (tensor<1x384x768xf32>, none)
531+
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_Y_]], [[VAR_Y_]]) : (tensor<1x384x768xf32>, tensor<1x384x768xf32>) -> tensor<1x384x768xf32>
532+
// CHECK: return [[VAR_1_]] : tensor<1x384x768xf32>
533+
// CHECK: }
534+
}
535+
536+
// -----
537+
514538
// COM: QLinearMatMul
515539
func.func @qlinear_matmul(%arg0: tensor<?x?x768xi8>, %arg1: tensor<f32>, %arg2: tensor<i8>, %arg3: tensor<768x768xi8>, %arg4: tensor<f32>, %arg5: tensor<i8>, %arg6: tensor<f32>, %arg7: tensor<i8>) -> (tensor<?x?x768xi8>) {
516540
%0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<?x?x768xi8>, tensor<f32>, tensor<i8>) -> tensor<?x?x768xf32>

0 commit comments

Comments
 (0)