Skip to content

Commit a05ff57

Browse files
committed
add transformation tests
1 parent a80f311 commit a05ff57

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir

Lines changed: 118 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1970,3 +1970,121 @@ func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf3
19701970
%0 = "test_dialect.op"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xf32>)
19711971
return %0 : tensor<2xf32>
19721972
}
1973+
1974+
1975+
// -----
1976+
1977+
/////////
1978+
// BatchNormInferenceOp
1979+
1980+
// CHECK-LABEL: @fuse_conv_bninf
// Positive test: a batch_norm_inference whose scale/offset/mean/variance are
// all constants and which consumes a convolution result should be fused away,
// leaving the convolution followed by a broadcast bias add (see the pattern
// directives below).
func.func @fuse_conv_bninf() -> (tensor<1x8x5x5xf32>) {
  // Splat constant operands so the fold's arithmetic is fully constant.
  %input = stablehlo.constant dense<33.0> : tensor<1x3x8x8xf32>
  %kernel = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
  // The supported layout: [b, f, ...]x[o, i, ...]->[b, f, ...] with both
  // group counts equal to 1.
  %conv = stablehlo.convolution(%input, %kernel)
      dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1],
      window = {}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
    : (tensor<1x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

  // One splat tensor reused for scale, offset, mean, and variance; the
  // feature_index (1) matches the conv output's feature dimension.
  %dummy = stablehlo.constant dense<1.0> : tensor<8xf32>
  %out = "stablehlo.batch_norm_inference"(%conv, %dummy, %dummy, %dummy, %dummy)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
    -> tensor<1x8x5x5xf32>

  // After the rewrite only conv + broadcast + add remain; the batch-norm op
  // must be gone.
  // CHECK-DAG: [[C0:%.+]] = stablehlo.convolution
  // CHECK-DAG: [[C1:%.+]] = stablehlo.broadcast_in_dim
  // CHECK-NOT: stablehlo.batch_norm_inference
  // CHECK: [[C2:%.+]] = stablehlo.add [[C0]], [[C1]]
  // CHECK: return [[C2]]
  return %out : tensor<1x8x5x5xf32>
}
2003+
2004+
// CHECK-LABEL: @fuse_conv_bninf_unsupported_group
// Negative test: the conv/batch-norm fusion presumably requires
// batch_group_count == 1 and feature_group_count == 1 (TODO confirm against
// FuseConvolutionBatchNormalization), so both batch_norm_inference ops below
// must survive the pass unchanged.
func.func @fuse_conv_bninf_unsupported_group()
    -> (tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>) {
  // Case 1: batch_group_count = 2 blocks the rewrite.
  %input1 = stablehlo.constant dense<33.0> : tensor<2x3x8x8xf32>
  %kernel1 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
  %conv1 = stablehlo.convolution(%input1, %kernel1)
      dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
      {batch_group_count = 2 : i64, feature_group_count = 1 : i64}
    : (tensor<2x3x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

  // Case 2: feature_group_count = 2 blocks the rewrite.
  %input2 = stablehlo.constant dense<33.0> : tensor<1x6x8x8xf32>
  %kernel2 = stablehlo.constant dense<0.1> : tensor<8x3x4x4xf32>
  %conv2 = stablehlo.convolution(%input2, %kernel2)
      dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
      {batch_group_count = 1 : i64, feature_group_count = 2 : i64}
    : (tensor<1x6x8x8xf32>, tensor<8x3x4x4xf32>) -> tensor<1x8x5x5xf32>

  // Identical constant batch-norm parameters for both ops.
  %cst = stablehlo.constant dense<1.0> : tensor<8xf32>
  %out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
    -> tensor<1x8x5x5xf32>

  %out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x8x5x5xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>)
    -> tensor<1x8x5x5xf32>

  // Both batch-norm ops are still present after the pass.
  // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: return [[C0]], [[C1]]
  return %out1, %out2 : tensor<1x8x5x5xf32>, tensor<1x8x5x5xf32>
}
2037+
2038+
// CHECK-LABEL: @fuse_conv_bninf_unsupported_configuration
// Negative test: each convolution below deviates from the only supported
// dimension-number layout, [b, f, ...]x[o, i, ...]->[b, f, ...], so none of
// the four batch_norm_inference consumers may be fused.
func.func @fuse_conv_bninf_unsupported_configuration()
    -> (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) {
  %input = stablehlo.constant dense<33.0> : tensor<1x1x1x1xf32>
  %kernel = stablehlo.constant dense<0.1> : tensor<1x1x1x1xf32>

  // Input batch/feature dimensions swapped ([f, b, ...]).
  %conv1 = stablehlo.convolution(%input, %kernel)
      dim_numbers = [f, b, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
    : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

  // Input batch/feature dimensions trailing instead of leading ([0, 1, f, b]).
  %conv2 = stablehlo.convolution(%input, %kernel)
      dim_numbers = [0, 1, f, b]x[o, i, 0, 1]->[b, f, 0, 1], window = {}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
    : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

  // Kernel input/output feature dimensions swapped ([i, o, ...]).
  %conv3 = stablehlo.convolution(%input, %kernel)
      dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
    : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

  // Kernel feature dimensions trailing instead of leading ([0, 1, o, i]).
  %conv4 = stablehlo.convolution(%input, %kernel)
      dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {}
      {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
    : (tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32>

  // Shared constant batch-norm parameters for all four ops.
  %cst = stablehlo.constant dense<1.0> : tensor<1xf32>

  %out1 = "stablehlo.batch_norm_inference"(%conv1, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
    -> tensor<1x1x1x1xf32>
  %out2 = "stablehlo.batch_norm_inference"(%conv2, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
    -> tensor<1x1x1x1xf32>
  %out3 = "stablehlo.batch_norm_inference"(%conv3, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
    -> tensor<1x1x1x1xf32>
  %out4 = "stablehlo.batch_norm_inference"(%conv4, %cst, %cst, %cst, %cst)
      <{epsilon = 1.0E-6 : f32, feature_index = 1 : i64}>
    : (tensor<1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>)
    -> tensor<1x1x1x1xf32>

  // All four batch-norm ops must remain after the pass.
  // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: [[C2:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: [[C3:%.+]] = "stablehlo.batch_norm_inference"
  // CHECK: return [[C0]], [[C1]], [[C2]], [[C3]]
  return %out1, %out2, %out3, %out4 : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>,
      tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>
}

stablehlo/transforms/StablehloAggressiveSimplification.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1496,10 +1496,10 @@ struct FuseConvolutionBatchNormalization final
14961496
auto dimNumbers = convOp.getDimensionNumbers();
14971497
if (dimNumbers.getInputBatchDimension() != 0 ||
14981498
dimNumbers.getInputFeatureDimension() != 1 ||
1499-
dimNumbers.getOutputBatchDimension() != 0 ||
1500-
dimNumbers.getOutputFeatureDimension() != 1 ||
15011499
dimNumbers.getKernelOutputFeatureDimension() != 0 ||
1502-
dimNumbers.getKernelInputFeatureDimension() != 1) {
1500+
dimNumbers.getKernelInputFeatureDimension() != 1 ||
1501+
dimNumbers.getOutputBatchDimension() != 0 ||
1502+
dimNumbers.getOutputFeatureDimension() != 1) {
15031503
constexpr StringLiteral msg =
15041504
"Only [b, f, ...]x[o, i, ...]->[b, f, ...] configuration is "
15051505
"supported";

0 commit comments

Comments (0)