Skip to content

Commit 92b2303

Browse files
committed
[mlir][tosa] Fix scatter duplicate indices check for int64
This commit fixes the validation check for duplicate indices in the TOSA scatter operation when using int64 index tensors. Previously, use of int64 index tensors would cause a crash. Change-Id: Ib234ad655d382863cc1fcb31877190d0d20d455e
1 parent 7ee0e0f commit 92b2303

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,23 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
216216

217217
bool mlir::tosa::hasUniqueConstantScatterIndices(
218218
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
219-
llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
219+
const llvm::ArrayRef<int64_t> indicesShape = indicesType.getShape();
220220
const unsigned int indicesRank = indicesShape.size();
221221
const unsigned int lastDimSize = indicesShape[indicesRank - 1];
222222

223223
// check each batch of indices from the flat indicesAttr values
224224
// for duplicates
225-
auto const indicesValues = indicesAttr.getValues<int32_t>();
225+
auto const indicesValues = indicesAttr.getValues<APInt>();
226226
assert(
227227
(indicesValues.size() % lastDimSize == 0) &&
228228
"Constant indices data length should be a multiple of indicesShape[-1]");
229229

230-
std::vector<uint64_t> indices(lastDimSize);
230+
std::vector<APInt> indices(lastDimSize);
231231
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
232232
beg += lastDimSize) {
233233
std::copy(beg, beg + lastDimSize, indices.begin());
234-
std::sort(indices.begin(), indices.end());
234+
std::sort(indices.begin(), indices.end(),
235+
[](const APInt &a, const APInt &b) { return a.slt(b); });
235236
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
236237
// found duplicate values in indices in batch
237238
return false;

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// validation flow.
55
//--------------------------------------------------------------------------------------------------
66

7-
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="profiles=pro_int,pro_fp extensions=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
7+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-attach-target="specification_version=1.1.draft profiles=pro_int,pro_fp extensions=int16,int4,int64,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow,doubleround,inexactround" -tosa-validate="strict-op-spec-alignment"
88

99

1010
func.func @test_cast(%arg0: tensor<i1>) -> tensor<5xi32> {
@@ -2044,6 +2044,16 @@ func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tens
20442044

20452045
// -----
20462046

2047+
// CHECK-LABEL: test_scatter_duplicate_indices_int64
2048+
func.func @test_scatter_duplicate_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
2049+
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
2050+
// expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
2051+
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
2052+
return %0 : tensor<2x52x3xf32>
2053+
}
2054+
2055+
// -----
2056+
20472057
func.func @test_reduce_all_unsupported_data_types(%arg0: tensor<2x12x11xf32>) -> tensor<1x12x11xf32> {
20482058
// expected-error@+1 {{'tosa.reduce_all' op illegal: operation operand/result data types did not align with any profile or extension, got (f32,f32), did you mean (i1,i1)?}}
20492059
%0 = tosa.reduce_all %arg0 {axis = 0 : i32} : (tensor<2x12x11xf32>) -> tensor<1x12x11xf32>

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
// -----
44

5+
// CHECK-LABEL: test_matmul_fp8_mixed_precision_operands
56
func.func @test_matmul_fp8_mixed_precision_operands(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
67
%azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
78
%bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
@@ -146,3 +147,12 @@ func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64
146147
%0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
147148
return %0 : tensor<12x16xi64>
148149
}
150+
151+
// -----
152+
153+
// CHECK-LABEL: test_scatter_const_indices_int64
154+
func.func @test_scatter_const_indices_int64(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
155+
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]> : tensor<2x12xi64> } : () -> tensor<2x12xi64>
156+
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi64>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
157+
return %0 : tensor<2x52x3xf32>
158+
}

0 commit comments

Comments
 (0)