Skip to content

Commit 8723fe5

Browse files
authored
[mlir][tosa] Allow int64 index tensors in gather/scatter (#167894)
This commit ensures that gather and scatter operations with int64 index tensors can be created. This aligns with the EXT_INT64 extension.
1 parent 31b7f1f commit 8723fe5

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,7 +2271,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
22712271

22722272
let arguments = (ins
22732273
Tosa_Tensor3D:$values,
2274-
Tosa_Int32Tensor2D:$indices
2274+
Tosa_IndexTensor2D:$indices
22752275
);
22762276

22772277
let results = (outs
@@ -2308,7 +2308,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
23082308

23092309
let arguments = (ins
23102310
Tosa_Tensor3D:$values_in,
2311-
Tosa_Int32Tensor2D:$indices,
2311+
Tosa_IndexTensor2D:$indices,
23122312
Tosa_Tensor3D:$input
23132313
);
23142314

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,8 @@ def Tosa_Tensor1Dto6D : AnyTypeOf<[
202202
def Tosa_TensorUpto4D : AnyTypeOf<[
203203
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
204204

205-
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
206-
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
207-
def Tosa_Int32Tensor2D : AnyTypeOf<[
208-
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [2]>]>;
205+
def Tosa_IndexTensor2D : AnyTypeOf<[
206+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32, Tosa_Int64], [2]>]>;
209207

210208
def Tosa_TensorAtLeast1D : AnyTypeOf<[
211209
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -757,10 +757,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
757757
}
758758

759759
// -----
760-
// CHECK-LABEL: scatter
761-
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
762-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
763-
return %0 : tensor<13x52x3xf32>
760+
// CHECK-LABEL: gather_int64
761+
func.func @test_gather_int64(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi64>) -> tensor<13x26x3xf32> {
762+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi64>) -> tensor<13x26x3xf32>
763+
return %0 : tensor<13x26x3xf32>
764764
}
765765

766766
// -----
@@ -770,6 +770,20 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
770770
return %0 : tensor<13x26x3xf32>
771771
}
772772

773+
// -----
774+
// CHECK-LABEL: scatter
775+
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
776+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
777+
return %0 : tensor<13x52x3xf32>
778+
}
779+
780+
// -----
781+
// CHECK-LABEL: scatter_int64
782+
func.func @test_scatter_int64(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi64>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
783+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi64>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
784+
return %0 : tensor<13x52x3xf32>
785+
}
786+
773787
// -----
774788
// CHECK-LABEL: scatter_unranked_indices
775789
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)