10 changes: 10 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -90,6 +90,15 @@ bool isNotConvProducer(mlir::Value val) {
return true; // If no defining op, assume it's safe
}

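// Return true if `attr` is an integer attribute whose value is 0, i.e. the
// Gemm `transB` flag is unset; any other attribute kind returns false.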
bool isTransBFalse(mlir::Attribute attr) {
  if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr)) {
    // getSExtValue is safe for the signless integers used by ONNX attributes.
    int64_t val = intAttr.getValue().getSExtValue();
    return val == 0; // transB == 0 means B is not transposed.
  }
  return false; // Conservative fallback for non-integer attributes.
}

// Get the index of the axis value in the given permutation array.
IntegerAttr getIndexOfAxisInPerm(
PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) {
@@ -1685,6 +1694,7 @@ void ONNXBatchNormalizationInferenceModeOp::getCanonicalizationPatterns(
results.insert<FuseBatchNormInferenceModeConvPattern>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern1>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern2>(context);
results.insert<BackwardFoldScaleAxisToGemmPattern>(context);
}

/// on the ONNXAddOp.
48 changes: 48 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -925,6 +925,54 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
[(HasRankOf<1> $x)], [], (addBenefit 0)
>;

//===----------------------------------------------------------------------===//
// Fold the composition 'BatchNormalization o Gemm' into a single 'Gemm'
// by deriving new 'B' and 'C' operands for the 'Gemm' after fusing the
// BatchNorm scale and bias.
//
// Given:
// (Gemm) Z = A * B + C
// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
//
// Assuming mean = 0, var = 1, and epsilon = 0, this transformation
// corresponds to the recomposition:
// Y = A * (scale * B) + (scale * C + bias)
Collaborator:

@Arkar-Hema could you elaborate on how you derived this formula, where mean, var, and eps cancel out?

Contributor Author:

Assume mean=0, var=1 and eps=0 (which is usually the case in pre-compiled, normalized models):

Y ≈ scale × Z + bias
Substituting Z = A × B + C:
Y ≈ scale × (A × B + C) + bias
Y = A × (scale × B) + (scale × C + bias)

Collaborator:

> Assume mean=0, var=1 and eps=0 (which is usually the case in pre-compiled, normalized models):

Then, you have to define this assumption in the constraint part of the rewriting rule. Otherwise, the rewriting rule produces a wrong result.

Anyway, my recommendation is to handle the general case where mean, var, and eps are constants (not necessarily the concrete values 0, 1, and 0, respectively). New scale and bias values for the matmul can easily be computed from the mean, var, eps, scale, and bias of BatchNorm; in inference mode these are all constants and will be folded automatically by the compiler into a single constant.
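For concreteness, one consistent derivation of those folded constants (a sketch; s is the effective per-feature scale, broadcast over the output features):

s = scale / sqrt(var + eps)
Y = s × (A × B + C − mean) + bias
  = A × (B × s) + (s × (C − mean) + bias)

So the new Gemm operands would be B' = B × s and C' = s × (C − mean) + bias, both of which are constants in inference mode.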

//
// Therefore, we rewrite:
// onnx.BatchNormalizationInferenceMode(
// onnx.Gemm(A, B, C, alpha, beta, transA, transB),
// scale, bias, mean, var
// ) {epsilon = ..., momentum = ...}
//
// as:
// onnx.Gemm(
// A,
// onnx.Mul(B, scale),
// onnx.Add(onnx.Mul(C, scale), bias),
// alpha, beta, transA, transB)
//
// This transformation is only valid when transB = 0, so that the
// columns of B align with scale for the element-wise multiplication.
//
//===----------------------------------------------------------------------===//

def isTransBFalse : Constraint<CPred<
"onnx_mlir::isTransBFalse($0)">, "TransB is 1 not 0"
>;

def BackwardFoldScaleAxisToGemmPattern : Pat<
(ONNXBatchNormalizationInferenceModeOp:$res
(ONNXGemmOp $A, $B, $C, $alpha, $beta, $transA, $transB),
$scale, $bias, $_mean, $_var, $_epsilon, $_momentum),
(ONNXGemmOp
$A,
(ONNXMulOp $B, $scale),
(ONNXAddOp
(ONNXMulOp $C, $scale),
$bias),
(GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
[(isTransBFalse $transB)],
[], (addBenefit 1)
>;
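As an aside, this fold (including the general constant mean/var/eps case suggested in review) can be sanity-checked numerically outside the compiler. A minimal standalone C++ sketch, independent of onnx-mlir, assuming alpha = beta = 1 and transA = transB = 0:

// Standalone numeric check (illustrative, not part of this patch): verify
// that Gemm followed by inference-mode BatchNorm equals one Gemm whose B
// and C have been folded with the BatchNorm constants.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 3, N = 2;
  const std::vector<float> A = {1, 2, 3, 4, 5, 6};                        // M x K
  const std::vector<float> B = {0.5f, -1.0f, 2.0f, 0.25f, -0.5f, 1.0f};  // K x N
  const std::vector<float> C = {0.1f, -0.2f};                            // N
  const std::vector<float> scale = {1.5f, 0.75f}, bias = {0.3f, -0.1f};
  const std::vector<float> mean = {0.2f, -0.4f}, var = {1.1f, 0.9f};
  const float eps = 1e-5f;

  // Plain Gemm with alpha = beta = 1, transA = transB = 0.
  auto gemm = [&](const std::vector<float> &b, const std::vector<float> &c) {
    std::vector<float> z(M * N);
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j) {
        float acc = c[j];
        for (int k = 0; k < K; ++k)
          acc += A[i * K + k] * b[k * N + j];
        z[i * N + j] = acc;
      }
    return z;
  };

  // Reference: Gemm, then BatchNorm in inference mode.
  const std::vector<float> z = gemm(B, C);
  std::vector<float> yRef(M * N);
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      yRef[i * N + j] =
          scale[j] * (z[i * N + j] - mean[j]) / std::sqrt(var[j] + eps) +
          bias[j];

  // Folded: s = scale / sqrt(var + eps); B' = B * s (per output column);
  // C' = s * (C - mean) + bias.
  std::vector<float> bFold(K * N), cFold(N);
  for (int j = 0; j < N; ++j) {
    const float s = scale[j] / std::sqrt(var[j] + eps);
    for (int k = 0; k < K; ++k)
      bFold[k * N + j] = B[k * N + j] * s;
    cFold[j] = s * (C[j] - mean[j]) + bias[j];
  }
  const std::vector<float> yFold = gemm(bFold, cFold);

  for (int i = 0; i < M * N; ++i)
    assert(std::fabs(yRef[i] - yFold[i]) < 1e-4f);
  std::puts("folded Gemm matches Gemm + BatchNorm");
  return 0;
}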

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXShapeOp
//===----------------------------------------------------------------------===//
26 changes: 26 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -784,6 +784,32 @@ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale

// -----

func.func @test_backward_fold_scale_axis(%arg0: tensor<1x256xf32>) -> tensor<1x128xf32> {
%0 = onnx.Constant dense<0.00999999977> : tensor<256x128xf32>
%1 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%2 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%3 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%4 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%5 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%6 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "onnx.Gemm_0", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%8 = "onnx.Relu"(%7) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %8 : tensor<1x128xf32>
// CHECK-LABEL: func @test_backward_fold_scale_axis
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256xf32>) -> tensor<1x128xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<256x128xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<128xf32>
// CHECK: [[MUL_0_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_1_]])
// CHECK: [[MUL_1_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]])
// CHECK: [[ADD_0_:%.+]] = "onnx.Add"([[MUL_1_]], [[VAR_1_]])
// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[MUL_0_]], [[ADD_0_]])
// CHECK-SAME: : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
// CHECK-NEXT: return [[VAR_3_]] : tensor<1x128xf32>
}

// -----

func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>