@@ -1964,3 +1964,26 @@ func.func @test_cast_fp32_i64(%arg0: tensor<1xf32>) -> (tensor<1xi64>) {
19641964 %0 = tosa.cast %arg0 : (tensor <1 xf32 >) -> tensor <1 xi64 >
19651965 return %0: tensor <1 xi64 >
19661966}
1967+
1968+ // -----
1969+
1970+ // CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d0, 0)>
1971+ // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (0, 0)>
1972+ // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
1973+
1974+ // CHECK-LABEL: func.func @test_add_0d_broadcast(
1975+ // CHECK-SAME: %[[ARG0:.*]]: tensor<2x1xf32>,
1976+ // CHECK-SAME: %[[ARG1:.*]]: tensor<f32>) -> tensor<2x1xf32> {
1977+ // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[ARG1]] [] output_shape [1, 1] : tensor<f32> into tensor<1x1xf32>
1978+ // CHECK: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<2x1xf32>
1979+ // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]], %[[EXPANDED]] : tensor<2x1xf32>, tensor<1x1xf32>) outs(%[[EMPTY_TENSOR]] : tensor<2x1xf32>) {
1980+ // CHECK: ^bb0(%[[IN0:.*]]: f32, %[[IN1:.*]]: f32, %[[OUT:.*]]: f32):
1981+ // CHECK: %[[ADD:.*]] = arith.addf %[[IN0]], %[[IN1]] : f32
1982+ // CHECK: linalg.yield %[[ADD]] : f32
1983+ // CHECK: } -> tensor<2x1xf32>
1984+ // CHECK: return %[[RESULT]] : tensor<2x1xf32>
1985+ // CHECK: }
1986+ func.func @test_add_0d_broadcast (%arg0: tensor <2 x1 xf32 >, %arg1: tensor <f32 >) -> tensor <2 x1 xf32 > {
1987+ %0 = tosa.add %arg0 , %arg1 : (tensor <2 x1 xf32 >, tensor <f32 >) -> tensor <2 x1 xf32 >
1988+ return %0 : tensor <2 x1 xf32 >
1989+ }
0 commit comments