@@ -1981,3 +1981,121 @@ func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf3
19811981 %0 = " test_dialect.op" (%arg0 , %arg1 ) : (tensor <2 xf32 >, tensor <2 xf32 >) -> (tensor <2 xf32 >)
19821982 return %0 : tensor <2 xf32 >
19831983}
1984+
1985+
1986+ // -----
1987+
1988+ /////////
1989+ // BatchNormInferenceOp
1990+
// Positive case: a convolution using the [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]
// layout with batch_group_count = 1 and feature_group_count = 1 feeds a
// batch_norm_inference op. The expectations at the end of the body require that
// no batch_norm_inference remains and that the returned value is an add of a
// convolution result and a broadcast_in_dim result.
// NOTE(review): presumably the rewrite folds the normalization scale into the
// convolution kernel and keeps the offset as the broadcast add - confirm
// against the pattern implementation.
1991+ // CHECK-LABEL: @fuse_conv_bninf
1992+ func.func @fuse_conv_bninf () -> (tensor <1 x8 x5 x5 xf32 >) {
// Splat constants stand in for the convolution input and kernel; the function
// takes no arguments.
1993+ %input = stablehlo.constant dense <33.0 > : tensor <1 x3 x8 x8 xf32 >
1994+ %kernel = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
1995+ %conv = stablehlo.convolution (%input , %kernel )
1996+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ],
1997+ window = {}
1998+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
1999+ : (tensor <1 x3 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
2000+
// One all-ones tensor<8xf32> is reused for all four per-feature operands of
// batch_norm_inference; its length matches dimension 1 (feature_index = 1) of
// the convolution result.
2001+ %dummy = stablehlo.constant dense <1.0 > : tensor <8 xf32 >
2002+ %out = " stablehlo.batch_norm_inference" (%conv , %dummy , %dummy , %dummy , %dummy )
2003+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2004+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
2005+ -> tensor <1 x8 x5 x5 xf32 >
2006+
// The convolution and the broadcast may be printed in either order, so they are
// matched order-independently; the add that consumes both must come afterwards.
2007+ // CHECK-DAG: [[C0:%.+]] = stablehlo.convolution
2008+ // CHECK-DAG: [[C1:%.+]] = stablehlo.broadcast_in_dim
2009+ // CHECK-NOT: stablehlo.batch_norm_inference
2010+ // CHECK: [[C2:%.+]] = stablehlo.add [[C0]], [[C1]]
2011+ // CHECK: return [[C2]]
2012+ return %out : tensor <1 x8 x5 x5 xf32 >
2013+ }
2014+
// Negative case for grouped convolutions: each of the two convolutions below
// sets one of the group counts to 2 and feeds its own batch_norm_inference.
// The expectations at the end of the body require both batch_norm_inference
// ops to still be present and to be the values returned.
2015+ // CHECK-LABEL: @fuse_conv_bninf_unsupported_group
2016+ func.func @fuse_conv_bninf_unsupported_group ()
2017+ -> (tensor <1 x8 x5 x5 xf32 >, tensor <1 x8 x5 x5 xf32 >) {
// Variant 1: batch_group_count = 2 (feature_group_count stays 1). The input has
// batch size 2 and the result has batch size 1.
2018+ %input1 = stablehlo.constant dense <33.0 > : tensor <2 x3 x8 x8 xf32 >
2019+ %kernel1 = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
2020+ %conv1 = stablehlo.convolution (%input1 , %kernel1 )
2021+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2022+ {batch_group_count = 2 : i64 , feature_group_count = 1 : i64 }
2023+ : (tensor <2 x3 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
2024+
// Variant 2: feature_group_count = 2 (batch_group_count stays 1). The input has
// 6 features while the kernel's input-feature dimension is 3.
2025+ %input2 = stablehlo.constant dense <33.0 > : tensor <1 x6 x8 x8 xf32 >
2026+ %kernel2 = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
2027+ %conv2 = stablehlo.convolution (%input2 , %kernel2 )
2028+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2029+ {batch_group_count = 1 : i64 , feature_group_count = 2 : i64 }
2030+ : (tensor <1 x6 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
2031+
// A single all-ones tensor<8xf32> is shared by both batch_norm_inference ops
// as all four of their per-feature operands.
2032+ %cst = stablehlo.constant dense <1.0 > : tensor <8 xf32 >
2033+ %out1 = " stablehlo.batch_norm_inference" (%conv1 , %cst , %cst , %cst , %cst )
2034+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2035+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
2036+ -> tensor <1 x8 x5 x5 xf32 >
2037+
2038+ %out2 = " stablehlo.batch_norm_inference" (%conv2 , %cst , %cst , %cst , %cst )
2039+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2040+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
2041+ -> tensor <1 x8 x5 x5 xf32 >
2042+
// Both normalization ops must survive, in order, and be returned unchanged.
2043+ // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2044+ // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2045+ // CHECK: return [[C0]], [[C1]]
2046+ return %out1 , %out2 : tensor <1 x8 x5 x5 xf32 >, tensor <1 x8 x5 x5 xf32 >
2047+ }
2048+
// Negative case for dimension layouts: four convolutions share the same input
// and kernel, keep both group counts at 1, and each changes exactly one side of
// dim_numbers relative to [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]. The
// expectations at the end of the body require all four batch_norm_inference
// ops to still be present and to be the values returned.
2049+ // CHECK-LABEL: @fuse_conv_bninf_unsupported_configuration
2050+ func.func @fuse_conv_bninf_unsupported_configuration ()
2051+ -> (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) {
// Every dimension has size 1, so the same tensor types stay valid for each of
// the permuted layouts used below.
2052+ %input = stablehlo.constant dense <33.0 > : tensor <1 x1 x1 x1 xf32 >
2053+ %kernel = stablehlo.constant dense <0.1 > : tensor <1 x1 x1 x1 xf32 >
2054+
// Variant 1: input layout is [f, b, 0, 1] (batch and feature swapped).
2055+ %conv1 = stablehlo.convolution (%input , %kernel )
2056+ dim_numbers = [f , b , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2057+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2058+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2059+
// Variant 2: input layout is [0, 1, f, b] (spatial dimensions first).
2060+ %conv2 = stablehlo.convolution (%input , %kernel )
2061+ dim_numbers = [0 , 1 , f , b ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2062+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2063+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2064+
// Variant 3: kernel layout is [i, o, 0, 1] (input and output features swapped).
2065+ %conv3 = stablehlo.convolution (%input , %kernel )
2066+ dim_numbers = [b , f , 0 , 1 ]x [i , o , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2067+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2068+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2069+
// Variant 4: kernel layout is [0, 1, o, i] (spatial dimensions first).
2070+ %conv4 = stablehlo.convolution (%input , %kernel )
2071+ dim_numbers = [b , f , 0 , 1 ]x [0 , 1 , o , i ]->[b , f , 0 , 1 ], window = {}
2072+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2073+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2074+
// A single all-ones tensor<1xf32> is shared by all four batch_norm_inference
// ops as all four of their per-feature operands.
2075+ %cst = stablehlo.constant dense <1.0 > : tensor <1 xf32 >
2076+
2077+ %out1 = " stablehlo.batch_norm_inference" (%conv1 , %cst , %cst , %cst , %cst )
2078+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2079+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2080+ -> tensor <1 x1 x1 x1 xf32 >
2081+ %out2 = " stablehlo.batch_norm_inference" (%conv2 , %cst , %cst , %cst , %cst )
2082+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2083+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2084+ -> tensor <1 x1 x1 x1 xf32 >
2085+ %out3 = " stablehlo.batch_norm_inference" (%conv3 , %cst , %cst , %cst , %cst )
2086+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2087+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2088+ -> tensor <1 x1 x1 x1 xf32 >
2089+ %out4 = " stablehlo.batch_norm_inference" (%conv4 , %cst , %cst , %cst , %cst )
2090+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2091+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2092+ -> tensor <1 x1 x1 x1 xf32 >
2093+
// All four normalization ops must survive, in order, and be returned unchanged.
2094+ // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2095+ // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2096+ // CHECK: [[C2:%.+]] = "stablehlo.batch_norm_inference"
2097+ // CHECK: [[C3:%.+]] = "stablehlo.batch_norm_inference"
2098+ // CHECK: return [[C0]], [[C1]], [[C2]], [[C3]]
2099+ return %out1 , %out2 , %out3 , %out4 : tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >,
2100+ tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >
2101+ }
0 commit comments