From 0c9ff3cf627fab75181a38e1d0235f07055d6658 Mon Sep 17 00:00:00 2001
From: Arkar-Hema
Date: Fri, 18 Apr 2025 05:30:10 -0400
Subject: [PATCH 1/4] Backward fold scale axis to gemm layer

Signed-off-by: Arkar-Hema
---
 src/Dialect/ONNX/ONNXOps/Canonicalize.cpp |  9 +++++
 src/Dialect/ONNX/ONNXOps/Canonicalize.td  | 48 +++++++++++++++++++++++
 test/mlir/onnx/onnx_canonicalization.mlir | 26 ++++++++++++
 3 files changed, 83 insertions(+)

diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index 62957db7af..f65cc300c2 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -90,6 +90,14 @@ bool isNotConvProducer(mlir::Value val) {
   return true; // If no defining op, assume it's safe
 }
 
+bool isTransBFalse(mlir::Attribute attr) {
+  if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
+    int64_t val = intAttr.getValue().getSExtValue(); // safe for signless integers
+    return val == 0; // return true if transB is false (0)
+  }
+  return false; // default fallback
+}
+
 // Get the index of the axis value in the given permutation array.
 IntegerAttr getIndexOfAxisInPerm(
     PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) {
@@ -1685,6 +1693,7 @@ void ONNXBatchNormalizationInferenceModeOp::getCanonicalizationPatterns(
   results.insert<FuseBatchNormInferenceModeConvPattern>(context);
   results.insert<RewriteBatchNormInferenceModeConvPattern1>(context);
   results.insert<RewriteBatchNormInferenceModeConvPattern2>(context);
+  results.insert<BackwardFoldScaleAxisToGemmPattern>(context);
 }
 
 /// on the ONNXAddOp.
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
index 5f4b05273b..799eb2b749 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -925,6 +925,54 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
   [(HasRankOf<1> $x)], [], (addBenefit 0)
 >;
 
+//===----------------------------------------------------------------------===//
+// This is to fold the composition: 'BatchNormalization o Gemm' into 'Gemm'
+// by deriving new 'B' and 'C' for 'Gemm' operation after fusing scale and bias.
+//
+// Given:
+// (Gemm) Z = A * B + C
+// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
+//
+// This transformation corresponds to a recomposition:
+// Y = A * (scale * B) + (scale * bias + C)
+//
+// Therefore, we rewrite:
+// onnx.BatchNormalizationInferenceMode(
+//   onnx.Gemm(A, B, C, alpha, beta, transA, transB),
+//   scale, bias, mean, var
+// ) {epsilon = ..., momentum = ...}
+//
+// as:
+// onnx.Gemm(
+//   A,
+//   onnx.Mul(B, scale),
+//   onnx.Add(onnx.Mul(bias, scale), C),
+//   alpha, beta, transA, transB)
+//
+// This transformation is only valid when transB = 0
+// to maintain the correct computation shape alignment.
+//
+//===----------------------------------------------------------------------===//
+
+def isTransBFalse : Constraint<
+  CPred<"onnx_mlir::isTransBFalse($0)">, "TransB is 1 not 0"
+>;
+
+def BackwardFoldScaleAxisToGemmPattern : Pat<
+  (ONNXBatchNormalizationInferenceModeOp:$res
+    (ONNXGemmOp $A, $B, $C, $alpha, $beta, $transA, $transB),
+    $scale, $bias, $_mean, $_var, $_epsilon, $_momentum),
+  (ONNXGemmOp
+    $A,
+    (ONNXMulOp $B, $scale),
+    (ONNXAddOp
+      (ONNXMulOp $bias, $scale),
+      $C),
+    (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
+  [(isTransBFalse $transB)],
+  [], (addBenefit 1)
+>;
+
 //===----------------------------------------------------------------------===//
 // Canonicalization for ONNXShapeOp
 //===----------------------------------------------------------------------===//
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 98f2ad5adb..60dc47d0a8 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -784,6 +784,32 @@ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale
 
 // -----
 
+func.func @test_backward_fold_scale_axis(%arg0: tensor<1x256xf32>) -> tensor<1x128xf32> {
+  %0 = onnx.Constant dense<0.00999999977> : tensor<256x128xf32>
+  %1 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %2 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %3 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %4 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %5 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %6 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "onnx.Gemm_0", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %8 = "onnx.Relu"(%7) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  return %8 : tensor<1x128xf32>
+  // CHECK-LABEL: func @test_backward_fold_scale_axis
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256xf32>) -> tensor<1x128xf32> {
+  // CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<256x128xf32>
+  // CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<128xf32>
+  // CHECK: [[MUL_0_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_1_]])
+  // CHECK: [[MUL_1_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]])
+  // CHECK: [[ADD_0_:%.+]] = "onnx.Add"([[MUL_1_]], [[VAR_1_]])
+  // CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[MUL_0_]], [[ADD_0_]])
+  // CHECK-SAME: : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  // CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
+  // CHECK-NEXT: return [[VAR_3_]] : tensor<1x128xf32>
+}
+
+// -----
+
 func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
   %cst = "onnx.NoValue"() {value} : () -> none
   %0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>

From f672ddfe66d76a3d6f856c86cf471cbcb6b07707 Mon Sep 17 00:00:00 2001
From: Arkar-Hema
Date: Fri, 18 Apr 2025 05:41:01 -0400
Subject: [PATCH 2/4] Clang format fix

Signed-off-by: Arkar-Hema
---
 src/Dialect/ONNX/ONNXOps/Canonicalize.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index f65cc300c2..13ef046deb 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -92,8 +92,9 @@ bool isNotConvProducer(mlir::Value val) {
 
 bool isTransBFalse(mlir::Attribute attr) {
   if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
-    int64_t val = intAttr.getValue().getSExtValue(); // safe for signless integers
-    return val == 0; // return true if transB is false (0)
+    int64_t val =
+        intAttr.getValue().getSExtValue(); // safe for signless integers
+    return val == 0;                       // return true if transB is false (0)
   }
   return false; // default fallback
 }

From 142687652059034ed9e5581e2761a4f32e65ae72 Mon Sep 17 00:00:00 2001
From: Arkar-Hema
Date: Thu, 24 Apr 2025 23:47:29 -0400
Subject: [PATCH 3/4] Backward fold batch to gemm

Signed-off-by: Arkar-Hema
---
 src/Dialect/ONNX/ONNXOps/Canonicalize.td | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
index 799eb2b749..2b2b82ed17 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -934,7 +934,7 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
 // (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
 //
 // This transformation corresponds to a recomposition:
-// Y = A * (scale * B) + (scale * bias + C)
+// Y = A * (scale * B) + (scale * C + bias)
 //
 // Therefore, we rewrite:
 // onnx.BatchNormalizationInferenceMode(
@@ -966,8 +966,8 @@ def BackwardFoldScaleAxisToGemmPattern : Pat<
     $A,
     (ONNXMulOp $B, $scale),
     (ONNXAddOp
-      (ONNXMulOp $bias, $scale),
-      $C),
+      (ONNXMulOp $C, $scale),
+      $bias),
     (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
   [(isTransBFalse $transB)],
   [], (addBenefit 1)

From 9619505470d0beb1618eda53e5f1ee495a6b7971 Mon Sep 17 00:00:00 2001
From: Arkar-Hema
Date: Thu, 15 May 2025 00:42:08 -0400
Subject: [PATCH 4/4] Added mean, var and eps constraints

Signed-off-by: Arkar-Hema
---
 src/Dialect/ONNX/ONNXOps/Canonicalize.cpp | 30 ++++++++++++++++
 src/Dialect/ONNX/ONNXOps/Canonicalize.td  | 44 ++++++++++++++++++-----
 test/mlir/onnx/onnx_canonicalization.mlir |  6 ++--
 3 files changed, 68 insertions(+), 12 deletions(-)

diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index b2e9008991..48ea1e7a98 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -99,6 +99,36 @@ bool isTransBFalse(mlir::Attribute attr) {
   return false; // default fallback
 }
 
+bool isZeroTensorOrSplat(Value val) {
+  if (auto constOp = val.getDefiningOp<ONNXConstantOp>()) {
+    auto attrOpt = constOp.getValue();
+    if (attrOpt.has_value()) {
+      if (auto dense = mlir::dyn_cast<DenseElementsAttr>(*attrOpt))
+        return dense.isSplat() && dense.getSplatValue<APFloat>().isZero();
+    }
+  }
+  return false;
+}
+
+bool isOneTensorOrSplat(Value val) {
+  if (auto constOp = val.getDefiningOp<ONNXConstantOp>()) {
+    auto attrOpt = constOp.getValue();
+    if (attrOpt.has_value()) {
+      if (auto dense = mlir::dyn_cast<DenseElementsAttr>(*attrOpt)) {
+        if (dense.isSplat())
+          return dense.getSplatValue<APFloat>().convertToDouble() == 1.0;
+      }
+    }
+  }
+  return false;
+}
+
+bool isZeroAttrOrZeroTensor(Attribute attr) {
+  if (auto floatAttr = mlir::dyn_cast<FloatAttr>(attr))
+    return floatAttr.getValue().isZero();
+  return false;
+}
+
 // Get the index of the axis value in the given permutation array.
 IntegerAttr getIndexOfAxisInPerm(
     PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) {
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.td b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
index fc15020243..998f73f42a 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.td
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -926,15 +926,21 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
 >;
 
 //===----------------------------------------------------------------------===//
-// This is to fold the composition: 'BatchNormalization o Gemm' into 'Gemm'
-// by deriving new 'B' and 'C' for 'Gemm' operation after fusing scale and bias.
+// This optimization folds the composition: 'BatchNormalization o Gemm' into 'Gemm'
+// by recomputing new 'B' and 'C' parameters for the Gemm operation by fusing
+// the BatchNormalization's scale and bias directly into them.
 //
 // Given:
 // (Gemm) Z = A * B + C
-// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
+// (BatchNormalization in inference mode)
+// Y = scale * (Z - mean) / sqrt(var + epsilon) + bias
 //
-// This transformation corresponds to a recomposition:
-// Y = A * (scale * B) + (scale * C + bias)
+// In inference mode, when mean=0, var=1, and epsilon=0, the BatchNormalization
+// simplifies to:
+// Y = scale * Z + bias
+//
+// This allows us to recompute:
+// Y = A * (scale * B) + (scale * C + bias)
 //
 // Therefore, we rewrite:
 // onnx.BatchNormalizationInferenceMode(
@@ -946,11 +952,14 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
 // onnx.Gemm(
 //   A,
 //   onnx.Mul(B, scale),
-//   onnx.Add(onnx.Mul(bias, scale), C),
+//   onnx.Add(onnx.Mul(C, scale), bias),
 //   alpha, beta, transA, transB)
 //
-// This transformation is only valid when transB = 0
-// to maintain the correct computation shape alignment.
+//
+// This transformation is only valid when:
+// - transB = 0 (to maintain correct shape alignment)
+// - mean is 0
+// - var is 1
+// - epsilon is 0
 //
 //===----------------------------------------------------------------------===//
@@ -958,6 +967,18 @@ def isTransBFalse : Constraint<
   CPred<"onnx_mlir::isTransBFalse($0)">, "TransB is 1 not 0"
 >;
 
+def meanIsZero : Constraint<
+  CPred<"onnx_mlir::isZeroTensorOrSplat($0)">, "mean must be 0"
+>;
+
+def varIsOne : Constraint<
+  CPred<"onnx_mlir::isOneTensorOrSplat($0)">, "var must be 1"
+>;
+
+def epsIsZero : Constraint<
+  CPred<"onnx_mlir::isZeroAttrOrZeroTensor($0)">, "epsilon must be 0"
+>;
+
 def BackwardFoldScaleAxisToGemmPattern : Pat<
   (ONNXBatchNormalizationInferenceModeOp:$res
     (ONNXGemmOp $A, $B, $C, $alpha, $beta, $transA, $transB),
     $scale, $bias, $_mean, $_var, $_epsilon, $_momentum),
@@ -969,7 +990,12 @@ def BackwardFoldScaleAxisToGemmPattern : Pat<
       (ONNXMulOp $C, $scale),
       $bias),
     (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
-  [(isTransBFalse $transB)],
+  [
+    (isTransBFalse $transB),
+    (meanIsZero $_mean),
+    (varIsOne $_var),
+    (epsIsZero $_epsilon)
+  ],
   [], (addBenefit 1)
 >;
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index d86c342512..e4bd90babb 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -817,10 +817,10 @@ func.func @test_backward_fold_scale_axis(%arg0: tensor<1x256xf32>) -> tensor<1x1
   %1 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
   %2 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
   %3 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
-  %4 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
-  %5 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
+  %4 = onnx.Constant dense<0.0> : tensor<128xf32>
+  %5 = onnx.Constant dense<1.0> : tensor<128xf32>
   %6 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "onnx.Gemm_0", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
-  %7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
+  %7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 0.0 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
   %8 = "onnx.Relu"(%7) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
   return %8 : tensor<1x128xf32>
   // CHECK-LABEL: func @test_backward_fold_scale_axis
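
The algebra behind the final pattern can be sanity-checked in isolation. Below is a minimal standalone C++ sketch, illustrative only and not part of the patch series: scalar stand-ins replace the tensors, alpha and beta are assumed to be 1, and all names are local to the example. Under the pattern's constraints (mean = 0, var = 1, epsilon = 0), BatchNorm(Gemm(A, B, C)) collapses to A * (scale * B) + (scale * C + bias), which is exactly the Gemm the rewrite emits.

// Standalone sanity check of the fold (illustrative sketch, not patch code).
// With mean = 0, var = 1, epsilon = 0, BatchNorm reduces to scale * Z + bias,
// so BatchNorm(Gemm(A, B, C)) == Gemm(A, scale * B, scale * C + bias).
#include <cassert>
#include <cmath>

int main() {
  const double A = 2.0, B = 3.0, C = 0.5;         // Gemm inputs (alpha = beta = 1)
  const double scale = 0.01, bias = 0.25;         // BatchNorm scale and bias
  const double mean = 0.0, var = 1.0, eps = 0.0;  // values the constraints require

  const double Z = A * B + C;                                        // Gemm
  const double Y = scale * (Z - mean) / std::sqrt(var + eps) + bias; // BatchNorm
  const double Yfolded = A * (scale * B) + (scale * C + bias);       // folded Gemm
  assert(std::fabs(Y - Yfolded) < 1e-12); // both paths agree
  return 0;
}

The same identity explains why the constraints are necessary: for a nonzero mean or non-unit variance, the fold would also have to rescale by 1/sqrt(var + eps) and shift by -mean, which this pattern deliberately does not attempt.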