Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,11 @@ bool getConstShapeValues(Operation *op,
// returns a small vector of int64_t values that attr contains
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
const int rank);

// returns true iff constant indices for scatter op contains unique indices
// per batch
bool hasUniqueConstantScatterIndices(ShapedType indicesType,
DenseIntElementsAttr indicesAttr);
} // namespace tosa
} // namespace mlir

Expand Down
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1244,10 +1244,36 @@ bool checkErrorIfCondIf(Operation *op) {
return true;
}

bool checkErrorIfScatter(Operation *op) {
auto scatterOp = dyn_cast<tosa::ScatterOp>(op);
if (!scatterOp)
return true;

// for constant indices, check that there are no duplicate values
DenseIntElementsAttr indicesAttr;
if (!matchPattern(scatterOp.getIndices(), m_Constant(&indicesAttr)))
return true;

auto const indicesType =
dyn_cast<ShapedType>(scatterOp.getIndices().getType());
if (!indicesType || !indicesType.hasRank()) {
op->emitOpError("expect ranked indices tensor");
return false;
}

if (!hasUniqueConstantScatterIndices(indicesType, indicesAttr)) {
op->emitOpError("indices values contain duplicates");
return false;
}

return true;
}

LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) {
if (!checkErrorIfResize(op) || !checkErrorIfMul(op) ||
!checkErrorIfTable(op) || !checkErrorIfRescale(op) ||
!checkErrorIfPad(op) || !checkErrorIfCondIf(op))
!checkErrorIfPad(op) || !checkErrorIfCondIf(op) ||
!checkErrorIfScatter(op))
return failure();
return success();
}
Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,30 @@ mlir::tosa::convertFromIntAttr(const DenseElementsAttr &attr, const int rank) {
}
return {};
}

bool mlir::tosa::hasUniqueConstantScatterIndices(
ShapedType indicesType, DenseIntElementsAttr indicesAttr) {
llvm::ArrayRef<int64_t> const indicesShape = indicesType.getShape();
const unsigned int indicesRank = indicesShape.size();
const unsigned int lastDimSize = indicesShape[indicesRank - 1];

// check each batch of indices from the flat indicesAttr values
// for duplicates
auto const indicesValues = indicesAttr.getValues<int32_t>();
assert(
(indicesValues.size() % lastDimSize == 0) &&
"Constant indices data length should be a multiple of indicesShape[-1]");

std::vector<uint64_t> indices(lastDimSize);
for (auto beg = indicesValues.begin(); beg < indicesValues.end();
beg += lastDimSize) {
std::copy(beg, beg + lastDimSize, indices.begin());
std::sort(indices.begin(), indices.end());
if (std::adjacent_find(indices.begin(), indices.end()) != indices.end()) {
// found duplicate values in indices in batch
return false;
}
}

return true;
}
10 changes: 10 additions & 0 deletions mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2015,3 +2015,13 @@ func.func @test_rescale_output_unsigned(%arg0: tensor<1x1xi8>) -> (tensor<1x1xui
%r = tosa.rescale %arg0, %1, %0, %3, %2 {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1xui8>
return %r : tensor<1x1xui8>
}

// -----

// CHECK-LABEL: test_scatter_duplicate_indices
func.func @test_scatter_duplicate_indices(%arg0: tensor<2x52x3xf32>, %arg2: tensor<2x12x3xf32>) -> tensor<2x52x3xf32> {
%indices = "tosa.const"() { values = dense<[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], [1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 11, 12]]> : tensor<2x12xi32> } : () -> tensor<2x12xi32>
// expected-error@+1 {{'tosa.scatter' op indices values contain duplicates}}
%0 = tosa.scatter %arg0, %indices, %arg2 : (tensor<2x52x3xf32>, tensor<2x12xi32>, tensor<2x12x3xf32>) -> tensor<2x52x3xf32>
return %0 : tensor<2x52x3xf32>
}
Loading