10 changes: 10 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -90,6 +90,15 @@ bool isNotConvProducer(mlir::Value val) {
return true; // If no defining op, assume it's safe
}

bool isTransBFalse(mlir::Attribute attr) {
  if (auto intAttr = attr.dyn_cast<mlir::IntegerAttr>()) {
    // getSExtValue is safe for ONNX's signless integer attributes.
    int64_t val = intAttr.getValue().getSExtValue();
    return val == 0; // transB == 0 means B is not transposed.
  }
  return false; // Conservatively reject missing or non-integer attributes.
}

// Get the index of the axis value in the given permutation array.
IntegerAttr getIndexOfAxisInPerm(
PatternRewriter &rewriter, ArrayAttr permAttr, IntegerAttr axis) {
@@ -1763,6 +1772,7 @@ void ONNXBatchNormalizationInferenceModeOp::getCanonicalizationPatterns(
results.insert<FuseBatchNormInferenceModeConvPattern>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern1>(context);
results.insert<RewriteBatchNormInferenceModeConvPattern2>(context);
results.insert<BackwardFoldScaleAxisToGemmPattern>(context);
}

/// on the ONNXAddOp.
48 changes: 48 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Canonicalize.td
@@ -925,6 +925,54 @@ def RewriteBatchNormInferenceModeConvPattern2: Pat<
[(HasRankOf<1> $x)], [], (addBenefit 0)
>;

//===----------------------------------------------------------------------===//
// Fold the composition 'BatchNormalization o Gemm' into a single 'Gemm' by
// deriving new 'B' and 'C' operands for the 'Gemm' that absorb the scale and
// bias.
//
// Given:
// (Gemm)      Z = A * B + C
// (BatchNorm) Y = scale * (Z - mean) / sqrt(var + eps) + bias
//
// Assuming the normalization statistics are already folded into scale and
// bias (i.e., mean = 0 and var + eps = 1), the BatchNorm reduces to
// Y = scale * Z + bias, and the composition recomposes as:
// Y = A * (scale * B) + (scale * C + bias)
//
// Therefore, we rewrite:
//   onnx.BatchNormalizationInferenceMode(
//       onnx.Gemm(A, B, C, alpha, beta, transA, transB),
//       scale, bias, mean, var
//   ) {epsilon = ..., momentum = ...}
//
// as:
//   onnx.Gemm(
//       A,
//       onnx.Mul(B, scale),
//       onnx.Add(onnx.Mul(C, scale), bias),
//       alpha, beta, transA, transB)
//
// This transformation is only valid when transB = 0: scale has shape [N] and
// must broadcast along the last axis of B, which holds only when B has shape
// [K, N].
//===----------------------------------------------------------------------===//

def isTransBFalse : Constraint<CPred<
  "onnx_mlir::isTransBFalse($0)">, "transB is 0 (B is not transposed)"
>;

def BackwardFoldScaleAxisToGemmPattern : Pat<
  (ONNXBatchNormalizationInferenceModeOp:$res
      (ONNXGemmOp $A, $B, $C, $alpha, $beta, $transA, $transB),
      $scale, $bias, $_mean, $_var, $_epsilon, $_momentum),
  (ONNXGemmOp
      $A,
      (ONNXMulOp $B, $scale),
      (ONNXAddOp
          (ONNXMulOp $C, $scale),
          $bias),
      (GemmAlpha), (GemmBeta), (GemmTransA), (GemmTransB)),
  [(isTransBFalse $transB)],
  [], (addBenefit 1)
>;

//===----------------------------------------------------------------------===//
// Canonicalization for ONNXShapeOp
//===----------------------------------------------------------------------===//
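As a quick numeric cross-check of the algebra behind BackwardFoldScaleAxisToGemmPattern above (an illustrative sketch only, not part of this patch): under the derivation's precondition that mean = 0 and var + eps = 1, the unfused and fused computations should agree. The NumPy snippet below uses the same shapes as the regression test that follows; all variable names are hypothetical.

import numpy as np

# Shapes match the test below: A is [1, 256], B is [K, N] = [256, 128].
rng = np.random.default_rng(0)
A = rng.standard_normal((1, 256)).astype(np.float32)
B = rng.standard_normal((256, 128)).astype(np.float32)
C = rng.standard_normal(128).astype(np.float32)
scale = rng.standard_normal(128).astype(np.float32)
bias = rng.standard_normal(128).astype(np.float32)

Z = A @ B + C                 # Gemm with alpha = beta = 1, transA = transB = 0
Y_ref = scale * Z + bias      # BatchNorm inference, assuming mean = 0, var + eps = 1
Y_fused = A @ (B * scale) + (C * scale + bias)  # the rewritten Gemm
assert np.allclose(Y_ref, Y_fused, atol=1e-4)

# With transB = 1, B would have shape [N, K] and `B * scale` would broadcast
# scale over the wrong axis, which is why the pattern is guarded by isTransBFalse.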
26 changes: 26 additions & 0 deletions test/mlir/onnx/onnx_canonicalization.mlir
@@ -812,6 +812,32 @@ func.func @test_rewrite_batchnormtestmode_1d_f16(%arg0 : tensor<64xf16>, %scale

// -----

func.func @test_backward_fold_scale_axis(%arg0: tensor<1x256xf32>) -> tensor<1x128xf32> {
%0 = onnx.Constant dense<0.00999999977> : tensor<256x128xf32>
%1 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%2 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%3 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%4 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%5 = onnx.Constant dense<0.00999999977> : tensor<128xf32>
%6 = "onnx.Gemm"(%arg0, %0, %1) {alpha = 1.000000e+00 : f32, beta = 1.000000e+00 : f32, onnx_node_name = "onnx.Gemm_0", transA = 0 : si64, transB = 0 : si64} : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%7 = "onnx.BatchNormalizationInferenceMode"(%6, %2, %3, %4, %5) {epsilon = 9.99999974E-6 : f32, momentum = 0.899999976 : f32, onnx_node_name = "onnx.BatchNormalizationInferenceMode_1"} : (tensor<1x128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
%8 = "onnx.Relu"(%7) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
return %8 : tensor<1x128xf32>
// CHECK-LABEL: func @test_backward_fold_scale_axis
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x256xf32>) -> tensor<1x128xf32> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<256x128xf32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<{{.*}}> : tensor<128xf32>
// CHECK: [[MUL_0_:%.+]] = "onnx.Mul"([[VAR_0_]], [[VAR_1_]])
// CHECK: [[MUL_1_:%.+]] = "onnx.Mul"([[VAR_1_]], [[VAR_1_]])
// CHECK: [[ADD_0_:%.+]] = "onnx.Add"([[MUL_1_]], [[VAR_1_]])
// CHECK: [[VAR_2_:%.+]] = "onnx.Gemm"([[PARAM_0_]], [[MUL_0_]], [[ADD_0_]])
// CHECK-SAME: : (tensor<1x256xf32>, tensor<256x128xf32>, tensor<128xf32>) -> tensor<1x128xf32>
// CHECK: [[VAR_3_:%.+]] = "onnx.Relu"([[VAR_2_]]) {onnx_node_name = "onnx.Relu_2"} : (tensor<1x128xf32>) -> tensor<1x128xf32>
// CHECK-NEXT: return [[VAR_3_]] : tensor<1x128xf32>
}

// -----

func.func @test_normalize_add(%arg0 : tensor<2xf32>) -> tensor<2xf32> {
%cst = "onnx.NoValue"() {value} : () -> none
%0 = onnx.Constant dense<[0.0, 1.0]> : tensor<2xf32>