@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
757757}
758758
759759// -----
760- // CHECK-LABEL: scatter
761- func.func @test_scatter (%arg0: tensor <13 x 52 x 3 x f32 >, %arg1: tensor <13 x 26 x i32 >, %arg2: tensor < 13 x 26 x 3 x f32 > ) -> tensor <13 x 52 x 3 x f32 > {
762- %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x 52 x 3 x f32 >, tensor <13 x 26 x i32 >, tensor < 13 x 26 x 3 x f32 > ) -> tensor <13 x 52 x 3 x f32 >
763- return %0 : tensor <13 x 52 x 3 x f32 >
760+ // CHECK-LABEL: gather_int64
761+ func.func @test_gather_int64 (%arg0: tensor <13 x 21 x 3 x f32 >, %arg1: tensor <13 x 26 x i64 > ) -> tensor <13 x 26 x 3 x f32 > {
762+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x 21 x 3 x f32 >, tensor <13 x 26 x i64 > ) -> tensor <13 x 26 x 3 x f32 >
763+ return %0 : tensor <13 x 26 x 3 x f32 >
764764}
765765
766766// -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
770770 return %0 : tensor <13 x26 x3 xf32 >
771771}
772772
773+ // -----
774+ // CHECK-LABEL: scatter
775+ func.func @test_scatter (%arg0: tensor <13 x52 x3 xf32 >, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf32 >) -> tensor <13 x52 x3 xf32 > {
776+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x52 x3 xf32 >, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf32 >) -> tensor <13 x52 x3 xf32 >
777+ return %0 : tensor <13 x52 x3 xf32 >
778+ }
779+
780+ // -----
781+ // CHECK-LABEL: scatter_int64
782+ func.func @test_scatter_int64 (%arg0: tensor <13 x52 x3 xf32 >, %arg1: tensor <13 x26 xi64 >, %arg2: tensor <13 x26 x3 xf32 >) -> tensor <13 x52 x3 xf32 > {
783+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x52 x3 xf32 >, tensor <13 x26 xi64 >, tensor <13 x26 x3 xf32 >) -> tensor <13 x52 x3 xf32 >
784+ return %0 : tensor <13 x52 x3 xf32 >
785+ }
786+
773787// -----
774788// CHECK-LABEL: scatter_unranked_indices
775789func.func @test_scatter_unranked_indices (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <*xi32 >, %arg2: tensor <13 x21 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
0 commit comments