Skip to content

Commit ea83d96

Browse files
committed
add transformation tests
1 parent 2525a23 commit ea83d96

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,3 +1981,121 @@ func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf3
19811981
%0 = "test_dialect.op"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>)
19821982
return %0 : tensor<2xf32>
19831983
}
1984+
1985+
1986+
// -----
1987+
1988+
/////////
1989+
// BatchNormInferenceOp
1990+
1991+
// CHECK-LABEL: @fuse_conv_bninf
1992+
func.func @fuse_conv_bninf() -> (tensor<1x8x5x5xf32>) {
1993+
%input = stablehlo.constant dense<33.0> : tensor<1x3x8x8xf32>
1994+
%kernel = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
1995+
%conv = stablehlo.convolution(%input, %kernel)
1996+
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
1997+
window = {}
1998+
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
1999+
: (tensor<1x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>
2000+
2001+
%dummy = stablehlo.constant dense<1.0> : tensor<8xf32>
2002+
%out = "stablehlo.batch_norm_inference"(%conv, %dummy, %dummy, %dummy, %dummy)
2003+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2004+
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
2005+
-> tensor<1x8x5x5xf32>
2006+
2007+
// CHECK-DAG: [[C0:%.+]] = stablehlo.convolution
2008+
// CHECK-DAG: [[C1:%.+]] = stablehlo.broadcast_in_dim
2009+
// CHECK-NOT: stablehlo.batch_norm_inference
2010+
// CHECK: [[C2:%.+]] = stablehlo.add [[C0]], [[C1]]
2011+
// CHECK: return [[C2]]
2012+
return %out : tensor<1x8x5x5xf32>
2013+
}
2014+
2015+
// CHECK-LABEL: @fuse_conv_bninf_unsupported_group
2016+
func.func @fuse_conv_bninf_unsupported_group()
2017+
-> (tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>) {
2018+
%input1 = stablehlo.constant dense<33.0> : tensor<2x3x8x8xf32>
2019+
%kernel1 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
2020+
%conv1 = stablehlo.convolution(%input1, %kernel1)
2021+
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
2022+
{batch_group_count = 2 : i64, feature_group_count = 1 : i64}
2023+
: (tensor<2x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>
2024+
2025+
%input2 = stablehlo.constant dense<33.0> : tensor<1x6x8x8xf32>
2026+
%kernel2 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
2027+
%conv2 = stablehlo.convolution(%input2, %kernel2)
2028+
dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
2029+
{batch_group_count = 1 : i64, feature_group_count = 2 : i64}
2030+
: (tensor<1x6x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>
2031+
2032+
%cst = stablehlo.constant dense<1.0> : tensor<8xf32>
2033+
%out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
2034+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2035+
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
2036+
-> tensor<1x8x5x5xf32>
2037+
2038+
%out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
2039+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2040+
: (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
2041+
-> tensor<1x8x5x5xf32>
2042+
2043+
// CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2044+
// CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2045+
// CHECK: return [[C0]], [[C1]]
2046+
return %out1, %out2 : tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>
2047+
}
2048+
2049+
// CHECK-LABEL: @fuse_conv_bninf_unsupported_configuration
2050+
func.func @fuse_conv_bninf_unsupported_configuration()
2051+
-> (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) {
2052+
%input = stablehlo.constant dense<33.0> : tensor<1x1x1x1xf32>
2053+
%kernel = stablehlo.constant dense<0.1> : tensor<1x1x1x1xf32>
2054+
2055+
%conv1 = stablehlo.convolution(%input, %kernel)
2056+
dim_numbers = [f, b, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
2057+
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
2058+
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
2059+
2060+
%conv2 = stablehlo.convolution(%input, %kernel)
2061+
dim_numbers = [0, 1, f, b]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
2062+
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
2063+
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
2064+
2065+
%conv3 = stablehlo.convolution(%input, %kernel)
2066+
dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {}
2067+
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
2068+
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
2069+
2070+
%conv4 = stablehlo.convolution(%input, %kernel)
2071+
dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {}
2072+
{batch_group_count = 1 : i64, feature_group_count = 1 : i64}
2073+
: (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>
2074+
2075+
%cst = stablehlo.constant dense<1.0> : tensor<1xf32>
2076+
2077+
%out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
2078+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2079+
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
2080+
-> tensor<1x1x1x1xf32>
2081+
%out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
2082+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2083+
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
2084+
-> tensor<1x1x1x1xf32>
2085+
%out3 = "stablehlo.batch_norm_inference"(%conv3, %cst, %cst, %cst, %cst)
2086+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2087+
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
2088+
-> tensor<1x1x1x1xf32>
2089+
%out4 = "stablehlo.batch_norm_inference"(%conv4, %cst, %cst, %cst, %cst)
2090+
<{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
2091+
: (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
2092+
-> tensor<1x1x1x1xf32>
2093+
2094+
// CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2095+
// CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2096+
// CHECK: [[C2:%.+]] = "stablehlo.batch_norm_inference"
2097+
// CHECK: [[C3:%.+]] = "stablehlo.batch_norm_inference"
2098+
// CHECK: return [[C0]], [[C1]], [[C2]], [[C3]]
2099+
return %out1, %out2, %out3, %out4 : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>,
2100+
tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>
2101+
}

stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,10 +1505,10 @@ struct FuseConvolutionBatchNormalization final
15051505
auto dimNumbers = convOp.getDimensionNumbers();
15061506
if (dimNumbers.getInputBatchDimension() != 0 ||
15071507
dimNumbers.getInputFeatureDimension() != 1 ||
1508-
dimNumbers.getOutputBatchDimension() != 0 ||
1509-
dimNumbers.getOutputFeatureDimension() != 1 ||
15101508
dimNumbers.getKernelOutputFeatureDimension() != 0 ||
1511-
dimNumbers.getKernelInputFeatureDimension() != 1) {
1509+
dimNumbers.getKernelInputFeatureDimension() != 1 ||
1510+
dimNumbers.getOutputBatchDimension() != 0 ||
1511+
dimNumbers.getOutputFeatureDimension() != 1) {
15121512
constexpr StringLiteral msg =
15131513
"Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
15141514
"supported";

0 commit comments

Comments
 (0)