@@ -15399,38 +15399,83 @@ struct ScatterIndicesAreUnique
1539915399 if (shape.empty())
1540015400 return failure();
1540115401
15402+ auto dimNumbers = op.getScatterDimensionNumbers();
15403+ int64_t indexVectorDim = dimNumbers.getIndexVectorDim();
15404+
1540215405 int64_t numTuples = 1;
15403- for (int64_t i = 0; i < shape.size() - 1; ++i) {
15404- numTuples *= shape[i];
15406+ for (int64_t i = 0; i < shape.size(); ++i) {
15407+ if (i != indexVectorDim) {
15408+ numTuples *= shape[i];
15409+ }
15410+ }
15411+ int64_t tupleSize = shape[indexVectorDim];
15412+
15413+ SmallVector<int64_t> strides(shape.size());
15414+ strides[shape.size() - 1] = 1;
15415+ for (int64_t i = shape.size() - 2; i >= 0; --i) {
15416+ strides[i] = strides[i + 1] * shape[i + 1];
15417+ }
15418+
15419+ SmallVector<int64_t> nonIndexVectorShape;
15420+ for (int64_t i = 0; i < shape.size(); ++i) {
15421+ if (i != indexVectorDim) {
15422+ nonIndexVectorShape.push_back(shape[i]);
15423+ }
1540515424 }
15406- int64_t tupleSize = shape.back();
1540715425
1540815426 // Iterate over the scatter indices tensor to extract tuples
1540915427 SmallVector<SmallVector<int64_t>> indexTuples;
1541015428 auto values = denseAttr.getValues<APInt>();
15411- auto it = values.begin();
15412- for (int64_t i = 0; i < numTuples; ++i) {
15413- SmallVector<int64_t> indexTuple;
15414- for (int64_t j = 0; j < tupleSize; ++j) {
15415- if (it == values.end()) {
15416- return failure(); // Unexpected end of values
15417- }
15418- indexTuple.push_back((*it).getSExtValue());
15419- ++it;
15420- }
15421- indexTuples.push_back(indexTuple);
15422- }
1542315429
15424- if (areIndexTuplesUnique(indexTuples)) {
15425- auto newOp = rewriter.create<stablehlo::ScatterOp>(
15426- op.getLoc(), op.getResultTypes(), op.getInputs(),
15427- op.getScatterIndices(), op.getUpdates(),
15428- op.getScatterDimensionNumbers(), op.getIndicesAreSortedAttr(),
15429- rewriter.getBoolAttr(true));
15430- newOp.getUpdateComputation().takeBody(op.getUpdateComputation());
15431- rewriter.replaceOp(op, newOp);
15432- return success();
15433- }
15430+ std::function<void(SmallVector<int64_t>, int64_t)> extractTuples =
15431+ [&](SmallVector<int64_t> currentIndices, int64_t dim) {
15432+ if (dim == nonIndexVectorShape.size()) {
15433+ SmallVector<int64_t> indexTuple;
15434+ for (int64_t component = 0; component < tupleSize; ++component) {
15435+ // Build full multi-dimensional index
15436+ SmallVector<int64_t> fullIndex;
15437+ int64_t nonIndexDim = 0;
15438+ for (int64_t d = 0; d < shape.size(); ++d) {
15439+ if (d == indexVectorDim) {
15440+ fullIndex.push_back(component);
15441+ } else {
15442+ fullIndex.push_back(currentIndices[nonIndexDim++]);
15443+ }
15444+ }
15445+
15446+ // Convert to linear index
15447+ int64_t linearIdx = 0;
15448+ for (int64_t d = 0; d < shape.size(); ++d) {
15449+ linearIdx += fullIndex[d] * strides[d];
15450+ }
15451+
15452+ auto it = values.begin();
15453+ std::advance(it, linearIdx);
15454+ indexTuple.push_back((*it).getSExtValue());
15455+ }
15456+ indexTuples.push_back(indexTuple);
15457+ return;
15458+ }
15459+
15460+ for (int64_t i = 0; i < nonIndexVectorShape[dim]; ++i) {
15461+ SmallVector<int64_t> newIndices = currentIndices;
15462+ newIndices.push_back(i);
15463+ extractTuples(newIndices, dim + 1);
15464+ }
15465+ };
15466+
15467+ extractTuples({}, 0);
15468+
15469+ bool uniqueIndices = areIndexTuplesUnique(indexTuples);
15470+ if (!uniqueIndices && !op.getUniqueIndices())
15471+ return failure();
15472+ auto newOp = rewriter.create<stablehlo::ScatterOp>(
15473+ op.getLoc(), op.getResultTypes(), op.getInputs(), scatterIndices,
15474+ op.getUpdates(), dimNumbers, op.getIndicesAreSortedAttr(),
15475+ rewriter.getBoolAttr(uniqueIndices));
15476+ newOp.getUpdateComputation().takeBody(op.getUpdateComputation());
15477+ rewriter.replaceOp(op, newOp);
15478+ return success();
1543415479 }
1543515480
1543615481 return failure();
@@ -15439,17 +15484,13 @@ struct ScatterIndicesAreUnique
1543915484private:
// Reports whether every entry of `indexTuples` is distinct: returns true when
// no two tuples compare equal, false as soon as a repeated tuple is seen.
// NOTE(review): this span is a diff rendering. Lines marked '-' are the removed
// nested-loop implementation that compared each pair with std::equal; lines
// marked '+' are its replacement, which inserts each tuple into a std::set.
1544015485 bool areIndexTuplesUnique(
1544115486 const SmallVector<SmallVector<int64_t>> &indexTuples) const {
15442- bool hasUnique = true;
15443- for (int64_t i = 0; i < indexTuples.size() && hasUnique; ++i) {
15444- for (int64_t j = i + 1; j < indexTuples.size() && hasUnique; ++j) {
15445- if (std::equal(indexTuples[i].begin(), indexTuples[i].end(),
15446- indexTuples[j].begin(), indexTuples[j].end())) {
15447- hasUnique = false;
15448- break;
15449- }
// Replacement: the set holds the tuples seen so far.
15487+ std::set<SmallVector<int64_t>> uniqueSet;
15488+ for (const auto &tuple : indexTuples) {
// insert(...).second is false when an equal tuple is already in the set.
15489+ if (!uniqueSet.insert(tuple).second) {
15490+ return false; // Duplicate found
1545015491 }
1545115492 }
15452- return hasUnique ;
// Reached only when every insertion above added a new element.
15493+ return true ;
1545315494 }
1545415495};
1545515496
0 commit comments