@@ -672,6 +672,63 @@ func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<2

// -----

// Static-shape f32 convolution whose bias operand is a splat constant (4.2).
// The checks below expect the bias to be materialised as a linalg.fill of a
// scalar arith.constant into a tensor.empty of the result shape, and that
// filled tensor to be used as the `outs` operand of linalg.conv_2d_nhwc_fhwc.
// Result spatial dims follow from the 3x3 kernel, zero padding, unit stride and
// dilation [2, 1]: 49 - 2*2 = 45 and 42 - 2*1 = 40.
// CHECK-LABEL: @conv2d_bias_broadcast_f32
func.func @conv2d_bias_broadcast_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
  // CHECK-DAG: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
  // CHECK: %[[BIAS:.+]] = linalg.fill
  // CHECK-SAME: ins(%[[CST]]
  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xf32>
  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME: outs(%[[BIAS]]
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
  return
}

// -----

// Same splat-constant bias case as the static f32 test, but with a dynamic
// batch dimension. The checks expect the batch size to be read from the input
// with tensor.dim (index 0) and passed to tensor.empty, so the filled bias
// tensor has type tensor<?x45x40x28xf32> before it is used as the `outs`
// operand of linalg.conv_2d_nhwc_fhwc.
// CHECK-LABEL: @conv2d_dynamic_batch_bias_broadcast_f32
// CHECK-SAME: (%[[INPUT:.+]]: tensor<?x49x42x27xf32>
func.func @conv2d_dynamic_batch_bias_broadcast_f32(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>) -> () {
  %bias = "tosa.const"() <{values = dense<4.20> : tensor<28xf32>}> : () -> tensor<28xf32>
  // CHECK: %[[C0:.+]] = arith.constant 0 : index
  // CHECK: %[[DIM:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x49x42x27xf32>
  // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x45x40x28xf32>
  // CHECK: %[[CST:.+]] = arith.constant 4.200000e+00 : f32
  // CHECK: %[[BIAS:.+]] = linalg.fill
  // CHECK-SAME: ins(%[[CST]]
  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<?x45x40x28xf32>
  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME: outs(%[[BIAS]]
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x45x40x28xf32>
  return
}

// -----

// Integer variant of the splat-constant bias case: i8 input, weights and bias
// with an i32 accumulator and i32 result. The checks expect the fill value to
// be an i32 constant (42) and the filled tensor<1x45x40x28xi32> to be used as
// the `outs` operand of linalg.conv_2d_nhwc_fhwc. Both zero points are 0.
// CHECK-LABEL: @conv2d_bias_broadcast_i8_acc_i32
func.func @conv2d_bias_broadcast_i8_acc_i32(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x3x3x27xi8>) -> () {
  %bias = "tosa.const"() <{values = dense<42> : tensor<28xi8>}> : () -> tensor<28xi8>
  // CHECK-DAG: %[[CST:.+]] = arith.constant 42 : i32
  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
  // CHECK: %[[BIAS:.+]] = linalg.fill
  // CHECK-SAME: ins(%[[CST]]
  // CHECK-SAME: outs(%[[EMPTY]]{{.+}} -> tensor<1x45x40x28xi32>
  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc
  // CHECK-SAME: outs(%[[BIAS]]
  %input_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %weight_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xi8>, tensor<28x3x3x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32>
  return
}

// -----

675732// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
676733// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
677734
0 commit comments