Skip to content

Conversation

@lhutton1
Copy link
Contributor

This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash.

This commit fixes the validation check for duplicate indices
in the TOSA scatter operation when using int64 index tensors.
Previously, use of int64 index tensors would cause a crash.

Change-Id: Ib234ad655d382863cc1fcb31877190d0d20d455e
@llvmbot llvmbot added the mlir label Nov 14, 2025
@lhutton1 lhutton1 requested a review from psunn November 14, 2025 16:40
@llvmbot
Copy link
Member

llvmbot commented Nov 14, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash.


Full diff: https://github.com/llvm/llvm-project/pull/168085.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp (+5-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+11-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+10)
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index ac5d6207259eb..62c015a85ee36 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
 
 bool mlir::tosa::hasUniqueConstantScatterIndices(
     ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
-  llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
+  const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
   const unsigned int indicesRank = indicesShape.size();
   const unsigned int lastDimSize = indicesShape[indicesRank - 1];
 
   // check each batch of indices from the flat indicesAttr values
   // for duplicates
-  auto const indicesValues = indicesAttr.getValues<int32_t>();
+  auto const indicesValues = indicesAttr.getValues<APInt>();
   assert(
       (indicesValues.size() % lastDimSize == 0) &&
       "Constant indices data length should be a multiple of indicesShape[-1]");
 
-  std::vector<uint64_t> indices(lastDimSize);
+  std::vector<APInt> indices(lastDimSize);
   for (auto beg = indicesValues.begin(); beg < indicesValues.end();
        beg += lastDimSize) {
     std::copy(beg, beg + lastDimSize, indices.begin());
-    std::sort(indices.begin(), indices.end());
+    std::sort(indices.begin(), indices.end(),
+              [](const APInt &a, const APInt &b) { return a.slt(b); });
     if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
       // found duplicate values in indices in batch
       return false;
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index c9e03ca53a729..3d24928487ed2 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -4,7 +4,7 @@
 // validation flow.
 //--------------------------------------------------------------------------------------------------
 
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
 
 
 func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
 
 // -----
 
+// CHECK-LABEL: test_scatter_duplicate_indices_int64
+func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+  // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}
+
+// -----
+
 func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
   // expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
   %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index acbff73b8b948..c285ae3cf44ee 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -2,6 +2,7 @@
 
 // -----
 
+// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands
 func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
   %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
   %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
@@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64
   %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
   return %0 : tensor<12x16xi64>
 }
+
+// -----
+
+// CHECK-LABEL: test_scatter_const_indices_int64
+func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
+  %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
+  %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
+  return %0 : tensor<2x52x3xf32>
+}

@lhutton1
Copy link
Contributor Author

cc @Tai78641 @IanTaylerLessa-arm

@Tai78641
Copy link
Contributor

LGTM

@IanTaylerLessa-arm
Copy link
Contributor

Looks good, thanks!

@lhutton1 lhutton1 merged commit 70b7958 into llvm:main Nov 14, 2025
13 checks passed
@lhutton1 lhutton1 deleted the fix-duplicate-indices-check branch November 14, 2025 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants