Skip to content

Commit f7a8d0f

Browse files
authored
fix: scatter unique indices for non-last index dim (#1031)
1 parent 70f97aa commit f7a8d0f

File tree

2 files changed

+87
-34
lines changed

2 files changed

+87
-34
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15399,38 +15399,83 @@ struct ScatterIndicesAreUnique
1539915399
if (shape.empty())
1540015400
return failure();
1540115401

15402+
auto dimNumbers = op.getScatterDimensionNumbers();
15403+
int64_t indexVectorDim = dimNumbers.getIndexVectorDim();
15404+
1540215405
int64_t numTuples = 1;
15403-
for (int64_t i = 0; i < shape.size() - 1; ++i) {
15404-
numTuples *= shape[i];
15406+
for (int64_t i = 0; i < shape.size(); ++i) {
15407+
if (i != indexVectorDim) {
15408+
numTuples *= shape[i];
15409+
}
15410+
}
15411+
int64_t tupleSize = shape[indexVectorDim];
15412+
15413+
SmallVector<int64_t> strides(shape.size());
15414+
strides[shape.size() - 1] = 1;
15415+
for (int64_t i = shape.size() - 2; i >= 0; --i) {
15416+
strides[i] = strides[i + 1] * shape[i + 1];
15417+
}
15418+
15419+
SmallVector<int64_t> nonIndexVectorShape;
15420+
for (int64_t i = 0; i < shape.size(); ++i) {
15421+
if (i != indexVectorDim) {
15422+
nonIndexVectorShape.push_back(shape[i]);
15423+
}
1540515424
}
15406-
int64_t tupleSize = shape.back();
1540715425

1540815426
// Iterate over the scatter indices tensor to extract tuples
1540915427
SmallVector<SmallVector<int64_t>> indexTuples;
1541015428
auto values = denseAttr.getValues<APInt>();
15411-
auto it = values.begin();
15412-
for (int64_t i = 0; i < numTuples; ++i) {
15413-
SmallVector<int64_t> indexTuple;
15414-
for (int64_t j = 0; j < tupleSize; ++j) {
15415-
if (it == values.end()) {
15416-
return failure(); // Unexpected end of values
15417-
}
15418-
indexTuple.push_back((*it).getSExtValue());
15419-
++it;
15420-
}
15421-
indexTuples.push_back(indexTuple);
15422-
}
1542315429

15424-
if (areIndexTuplesUnique(indexTuples)) {
15425-
auto newOp = rewriter.create<stablehlo::ScatterOp>(
15426-
op.getLoc(), op.getResultTypes(), op.getInputs(),
15427-
op.getScatterIndices(), op.getUpdates(),
15428-
op.getScatterDimensionNumbers(), op.getIndicesAreSortedAttr(),
15429-
rewriter.getBoolAttr(true));
15430-
newOp.getUpdateComputation().takeBody(op.getUpdateComputation());
15431-
rewriter.replaceOp(op, newOp);
15432-
return success();
15433-
}
15430+
std::function<void(SmallVector<int64_t>, int64_t)> extractTuples =
15431+
[&](SmallVector<int64_t> currentIndices, int64_t dim) {
15432+
if (dim == nonIndexVectorShape.size()) {
15433+
SmallVector<int64_t> indexTuple;
15434+
for (int64_t component = 0; component < tupleSize; ++component) {
15435+
// Build full multi-dimensional index
15436+
SmallVector<int64_t> fullIndex;
15437+
int64_t nonIndexDim = 0;
15438+
for (int64_t d = 0; d < shape.size(); ++d) {
15439+
if (d == indexVectorDim) {
15440+
fullIndex.push_back(component);
15441+
} else {
15442+
fullIndex.push_back(currentIndices[nonIndexDim++]);
15443+
}
15444+
}
15445+
15446+
// Convert to linear index
15447+
int64_t linearIdx = 0;
15448+
for (int64_t d = 0; d < shape.size(); ++d) {
15449+
linearIdx += fullIndex[d] * strides[d];
15450+
}
15451+
15452+
auto it = values.begin();
15453+
std::advance(it, linearIdx);
15454+
indexTuple.push_back((*it).getSExtValue());
15455+
}
15456+
indexTuples.push_back(indexTuple);
15457+
return;
15458+
}
15459+
15460+
for (int64_t i = 0; i < nonIndexVectorShape[dim]; ++i) {
15461+
SmallVector<int64_t> newIndices = currentIndices;
15462+
newIndices.push_back(i);
15463+
extractTuples(newIndices, dim + 1);
15464+
}
15465+
};
15466+
15467+
extractTuples({}, 0);
15468+
15469+
bool uniqueIndices = areIndexTuplesUnique(indexTuples);
15470+
if (!uniqueIndices && !op.getUniqueIndices())
15471+
return failure();
15472+
auto newOp = rewriter.create<stablehlo::ScatterOp>(
15473+
op.getLoc(), op.getResultTypes(), op.getInputs(), scatterIndices,
15474+
op.getUpdates(), dimNumbers, op.getIndicesAreSortedAttr(),
15475+
rewriter.getBoolAttr(uniqueIndices));
15476+
newOp.getUpdateComputation().takeBody(op.getUpdateComputation());
15477+
rewriter.replaceOp(op, newOp);
15478+
return success();
1543415479
}
1543515480

1543615481
return failure();
@@ -15439,17 +15484,13 @@ struct ScatterIndicesAreUnique
1543915484
private:
1544015485
bool areIndexTuplesUnique(
1544115486
const SmallVector<SmallVector<int64_t>> &indexTuples) const {
15442-
bool hasUnique = true;
15443-
for (int64_t i = 0; i < indexTuples.size() && hasUnique; ++i) {
15444-
for (int64_t j = i + 1; j < indexTuples.size() && hasUnique; ++j) {
15445-
if (std::equal(indexTuples[i].begin(), indexTuples[i].end(),
15446-
indexTuples[j].begin(), indexTuples[j].end())) {
15447-
hasUnique = false;
15448-
break;
15449-
}
15487+
std::set<SmallVector<int64_t>> uniqueSet;
15488+
for (const auto &tuple : indexTuples) {
15489+
if (!uniqueSet.insert(tuple).second) {
15490+
return false; // Duplicate found
1545015491
}
1545115492
}
15452-
return hasUnique;
15493+
return true;
1545315494
}
1545415495
};
1545515496

test/lit_tests/scatteruniqueindices.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,15 @@ func.func @test_scatter_single(%arg0: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -
3737
}) : (tensor<2x3xf32>, tensor<1x1xi32>, tensor<1x3xf32>) -> tensor<2x3xf32>
3838
return %0 : tensor<2x3xf32>
3939
}
40+
41+
// CHECK-LABEL: func.func @test_scatter_unique2
42+
func.func @test_scatter_unique2(%arg0: tensor<5xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<5xf32>) -> tensor<5xf32> {
43+
%c = stablehlo.constant dense<[[3, 0, 0, 4, 2]]> : tensor<1x5xi64>
44+
// CHECK: %{{.+}} = "stablehlo.scatter"(%{{.+}}, %{{.+}}, %{{.+}}) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>}> ({
45+
%0 = "stablehlo.scatter"(%arg0, %c, %arg1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0], scatter_dims_to_operand_dims = [0]>}> ({
46+
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
47+
%1 = stablehlo.add %arg2, %arg3 : tensor<f32>
48+
stablehlo.return %1 : tensor<f32>
49+
}) : (tensor<5xf32>, tensor<1x5xi64>, tensor<5xf32>) -> tensor<5xf32>
50+
return %0 : tensor<5xf32>
51+
}

0 commit comments

Comments
 (0)