From 92b2303105e13e7a6bcfe804ee0fcc44577c56a7 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Fri, 14 Nov 2025 14:57:04 +0000 Subject: [PATCH] [mlir][tosa] Fix scatter duplicate indices check for int64 This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash. Change-Id: Ib234ad655d382863cc1fcb31877190d0d20d455e --- mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp | 9 +++++---- mlir/test/Dialect/Tosa/invalid.mlir | 12 +++++++++++- .../Tosa/tosa-validation-version-1p1-valid.mlir | 10 ++++++++++ 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index ac5d6207259eb..62c015a85ee36 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) { bool mlir::tosa::hasUniqueConstantScatterIndices( ShapedType indicesType, DenseIntElementsAttr indicesAttr) { - llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape(); + const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape(); const unsigned int indicesRank = indicesShape.size(); const unsigned int lastDimSize = indicesShape[indicesRank - 1]; // check each batch of indices from the flat indicesAttr values // for duplicates - auto const indicesValues = indicesAttr.getValues<int32_t>(); + auto const indicesValues = indicesAttr.getValues<APInt>(); assert( (indicesValues.size() % lastDimSize == 0) && "Constant indices data length should be a multiple of indicesShape[-1]"); - std::vector<uint64_t> indices(lastDimSize); + std::vector<APInt> indices(lastDimSize); for (auto beg = indicesValues.begin(); beg < indicesValues.end(); beg += lastDimSize) { std::copy(beg, beg + lastDimSize, indices.begin()); - std::sort(indices.begin(), indices.end()); + std::sort(indices.begin(),
indices.end(), + [](const APInt &a, const APInt &b) { return a.slt(b); }); if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) { // found duplicate values in indices in batch return false; diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index c9e03ca53a729..3d24928487ed2 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment" func.func @test_cast(%arg0: tensor) -> tensor<5xi32> { @@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens // ----- +// CHECK-LABEL: test_scatter_duplicate_indices_int64 +func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + // expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}} + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +} + +// ----- + func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> { // expected-error@+1 {{'tosa.reduce_all' 
op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}} %0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32> diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir index acbff73b8b948..c285ae3cf44ee 100644 --- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir +++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir @@ -2,6 +2,7 @@ // ----- +// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> { %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> @@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64 %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64> return %0 : tensor<12x16xi64> } + +// ----- + +// CHECK-LABEL: test_scatter_const_indices_int64 +func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> { + %indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64> + %0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32> + return %0 : tensor<2x52x3xf32> +}