@@ -1970,3 +1970,121 @@ func.func @generic_op(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xf3
19701970 %0 = " test_dialect.op" (%arg0 , %arg1 ) : (tensor <2 xf32 >, tensor <2 xf32 >) -> (tensor <2 xf32 >)
19711971 return %0 : tensor <2 xf32 >
19721972}
1973+
1974+
1975+ // -----
1976+
1977+ /////////
1978+ // BatchNormInferenceOp
1979+
1980+ // CHECK-LABEL: @fuse_conv_bninf
1981+ func.func @fuse_conv_bninf () -> (tensor <1 x8 x5 x5 xf32 >) {
1982+ %input = stablehlo.constant dense <33.0 > : tensor <1 x3 x8 x8 xf32 >
1983+ %kernel = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
1984+ %conv = stablehlo.convolution (%input , %kernel )
1985+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ],
1986+ window = {}
1987+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
1988+ : (tensor <1 x3 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
1989+
1990+ %dummy = stablehlo.constant dense <1.0 > : tensor <8 xf32 >
1991+ %out = " stablehlo.batch_norm_inference" (%conv , %dummy , %dummy , %dummy , %dummy )
1992+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
1993+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
1994+ -> tensor <1 x8 x5 x5 xf32 >
1995+
1996+ // CHECK-DAG: [[C0:%.+]] = stablehlo.convolution
1997+ // CHECK-DAG: [[C1:%.+]] = stablehlo.broadcast_in_dim
1998+ // CHECK-NOT: stablehlo.batch_norm_inference
1999+ // CHECK: [[C2:%.+]] = stablehlo.add [[C0]], [[C1]]
2000+ // CHECK: return [[C2]]
2001+ return %out : tensor <1 x8 x5 x5 xf32 >
2002+ }
2003+
2004+ // CHECK-LABEL: @fuse_conv_bninf_unsupported_group
2005+ func.func @fuse_conv_bninf_unsupported_group ()
2006+ -> (tensor <1 x8 x5 x5 xf32 >, tensor <1 x8 x5 x5 xf32 >) {
2007+ %input1 = stablehlo.constant dense <33.0 > : tensor <2 x3 x8 x8 xf32 >
2008+ %kernel1 = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
2009+ %conv1 = stablehlo.convolution (%input1 , %kernel1 )
2010+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2011+ {batch_group_count = 2 : i64 , feature_group_count = 1 : i64 }
2012+ : (tensor <2 x3 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
2013+
2014+ %input2 = stablehlo.constant dense <33.0 > : tensor <1 x6 x8 x8 xf32 >
2015+ %kernel2 = stablehlo.constant dense <0.1 > : tensor <8 x3 x4 x4 xf32 >
2016+ %conv2 = stablehlo.convolution (%input2 , %kernel2 )
2017+ dim_numbers = [b , f , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2018+ {batch_group_count = 1 : i64 , feature_group_count = 2 : i64 }
2019+ : (tensor <1 x6 x8 x8 xf32 >, tensor <8 x3 x4 x4 xf32 >) -> tensor <1 x8 x5 x5 xf32 >
2020+
2021+ %cst = stablehlo.constant dense <1.0 > : tensor <8 xf32 >
2022+ %out1 = " stablehlo.batch_norm_inference" (%conv1 , %cst , %cst , %cst , %cst )
2023+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2024+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
2025+ -> tensor <1 x8 x5 x5 xf32 >
2026+
2027+ %out2 = " stablehlo.batch_norm_inference" (%conv2 , %cst , %cst , %cst , %cst )
2028+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2029+ : (tensor <1 x8 x5 x5 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >, tensor <8 xf32 >)
2030+ -> tensor <1 x8 x5 x5 xf32 >
2031+
2032+ // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2033+ // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2034+ // CHECK: return [[C0]], [[C1]]
2035+ return %out1 , %out2 : tensor <1 x8 x5 x5 xf32 >, tensor <1 x8 x5 x5 xf32 >
2036+ }
2037+
2038+ // CHECK-LABEL: @fuse_conv_bninf_unsupported_configuration
2039+ func.func @fuse_conv_bninf_unsupported_configuration ()
2040+ -> (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) {
2041+ %input = stablehlo.constant dense <33.0 > : tensor <1 x1 x1 x1 xf32 >
2042+ %kernel = stablehlo.constant dense <0.1 > : tensor <1 x1 x1 x1 xf32 >
2043+
2044+ %conv1 = stablehlo.convolution (%input , %kernel )
2045+ dim_numbers = [f , b , 0 , 1 ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2046+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2047+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2048+
2049+ %conv2 = stablehlo.convolution (%input , %kernel )
2050+ dim_numbers = [0 , 1 , f , b ]x [o , i , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2051+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2052+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2053+
2054+ %conv3 = stablehlo.convolution (%input , %kernel )
2055+ dim_numbers = [b , f , 0 , 1 ]x [i , o , 0 , 1 ]->[b , f , 0 , 1 ], window = {}
2056+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2057+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2058+
2059+ %conv4 = stablehlo.convolution (%input , %kernel )
2060+ dim_numbers = [b , f , 0 , 1 ]x [0 , 1 , o , i ]->[b , f , 0 , 1 ], window = {}
2061+ {batch_group_count = 1 : i64 , feature_group_count = 1 : i64 }
2062+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >) -> tensor <1 x1 x1 x1 xf32 >
2063+
2064+ %cst = stablehlo.constant dense <1.0 > : tensor <1 xf32 >
2065+
2066+ %out1 = " stablehlo.batch_norm_inference" (%conv1 , %cst , %cst , %cst , %cst )
2067+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2068+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2069+ -> tensor <1 x1 x1 x1 xf32 >
2070+ %out2 = " stablehlo.batch_norm_inference" (%conv2 , %cst , %cst , %cst , %cst )
2071+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2072+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2073+ -> tensor <1 x1 x1 x1 xf32 >
2074+ %out3 = " stablehlo.batch_norm_inference" (%conv3 , %cst , %cst , %cst , %cst )
2075+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2076+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2077+ -> tensor <1 x1 x1 x1 xf32 >
2078+ %out4 = " stablehlo.batch_norm_inference" (%conv4 , %cst , %cst , %cst , %cst )
2079+ <{epsilon = 1.0E-6 : f32 , feature_index = 1 : i64 }>
2080+ : (tensor <1 x1 x1 x1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >)
2081+ -> tensor <1 x1 x1 x1 xf32 >
2082+
2083+ // CHECK: [[C0:%.+]] = "stablehlo.batch_norm_inference"
2084+ // CHECK: [[C1:%.+]] = "stablehlo.batch_norm_inference"
2085+ // CHECK: [[C2:%.+]] = "stablehlo.batch_norm_inference"
2086+ // CHECK: [[C3:%.+]] = "stablehlo.batch_norm_inference"
2087+ // CHECK: return [[C0]], [[C1]], [[C2]], [[C3]]
2088+ return %out1 , %out2 , %out3 , %out4 : tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >,
2089+ tensor <1 x1 x1 x1 xf32 >, tensor <1 x1 x1 x1 xf32 >
2090+ }
0 commit comments