Skip to content

Commit 70b7958

Browse files
authored
[mlir][tosa] Fix scatter duplicate indices check for int64 (#168085)
This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash.
1 parent f7a8d20 commit 70b7958

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
216216

217217
bool mlir::tosa::hasUniqueConstantScatterIndices(
218218
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
219-
llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
219+
const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
220220
const unsigned int indicesRank = indicesShape.size();
221221
const unsigned int lastDimSize = indicesShape[indicesRank - 1];
222222

223223
// check each batch of indices from the flat indicesAttr values
224224
// for duplicates
225-
auto const indicesValues = indicesAttr.getValues<int32_t>();
225+
auto const indicesValues = indicesAttr.getValues<APInt>();
226226
assert(
227227
(indicesValues.size() % lastDimSize == 0) &&
228228
"Constant indices data length should be a multiple of indicesShape[-1]");
229229

230-
std::vector<uint64_t> indices(lastDimSize);
230+
std::vector<APInt> indices(lastDimSize);
231231
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
232232
beg += lastDimSize) {
233233
std::copy(beg, beg + lastDimSize, indices.begin());
234-
std::sort(indices.begin(), indices.end());
234+
std::sort(indices.begin(), indices.end(),
235+
[](const APInt &a, const APInt &b) { return a.slt(b); });
235236
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
236237
// found duplicate values in indices in batch
237238
return false;

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// validation flow.
55
//--------------------------------------------------------------------------------------------------
66

7-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
7+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
88

99

1010
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
20442044

20452045
// -----
20462046

2047+
// CHECK-LABEL: test_scatter_duplicate_indices_int64
2048+
func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
2049+
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
2050+
// expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
2051+
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
2052+
return %0 : tensor<2x52x3xf32>
2053+
}
2054+
2055+
// -----
2056+
20472057
func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
20482058
// expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
20492059
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// -----
44

5+
// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands
56
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
67
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
78
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
@@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64
146147
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
147148
return %0 : tensor<12x16xi64>
148149
}
150+
151+
// -----
152+
153+
// CHECK-LABEL: test_scatter_const_indices_int64
154+
func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
155+
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
156+
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
157+
return %0 : tensor<2x52x3xf32>
158+
}

0 commit comments

Comments
 (0)