@@ -76,3 +76,25 @@ func.func @depthwise_conv2d_as_mul_padded(%arg0: tensor<4x10x10x2xf32>, %arg1: t
7676 %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %zp , %zp {acc_type = f32 , pad = array<i64 : 1 , 1 , 1 , 1 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 xf32 >, tensor <1 x1 x2 x3 xf32 >, tensor <6 xf32 >, tensor <1 xf32 >, tensor <1 xf32 >) -> tensor <4 x12 x12 x6 xf32 >
7777 return %0 : tensor <4 x12 x12 x6 xf32 >
7878}
79+
80+ // -----
81+
82+ // Decompose only support integer or float types.
83+
84+ // CHECK-LABEL: @depthwise_conv2d_quant_type
85+ func.func @depthwise_conv2d_quant_type (%arg0: tensor <4 x10 x10 x2 x!quant.uniform <i8 :f32 , 0.015684768557548523 >>, %arg1: tensor <1 x1 x2 x3 x!quant.uniform <i8 <-127 :127 >:f32 , 0.015680249780416489 >>, %arg2: tensor <6 xi32 >) -> tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >> {
86+ %0 = " tosa.const" () <{value = dense <7 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
87+ %1 = " tosa.const" () <{value = dense <11 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
88+ // CHECK: tosa.depthwise_conv2d
89+ %2 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %0 , %1 {acc_type = i32 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 x!quant.uniform <i8 :f32 , 0.015684768557548523 >>, tensor <1 x1 x2 x3 x!quant.uniform <i8 <-127 :127 >:f32 , 0.015680249780416489 >>, tensor <6 xi32 >, tensor <1 xi8 >, tensor <1 xi8 >) -> tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >>
90+ return %2 : tensor <4 x10 x10 x6 x!quant.uniform <i32 :f32 , 0.078431375324726104 >>
91+ }
92+
93+ // -----
94+
95+ // CHECK-LABEL: @depthwise_conv2d_no_const_zero_point
96+ func.func @depthwise_conv2d_no_const_zero_point (%arg0: tensor <4 x10 x10 x2 xi8 >, %arg1: tensor <1 x1 x2 x3 xi8 >, %arg2: tensor <6 xi32 >, %arg3: tensor <1 xi8 >, %arg4: tensor <1 xi8 >) -> tensor <4 x10 x10 x6 xi32 > {
97+ // CHECK: tosa.depthwise_conv2d
98+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = i32 , pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, dilation = array<i64 : 1 , 1 >} : (tensor <4 x10 x10 x2 xi8 >, tensor <1 x1 x2 x3 xi8 >, tensor <6 xi32 >, tensor <1 xi8 >, tensor <1 xi8 >) -> tensor <4 x10 x10 x6 xi32 >
99+ return %0 : tensor <4 x10 x10 x6 xi32 >
100+ }
0 commit comments