Skip to content

Commit 1f8540e

Browse files
ZixuanJiang authored and Google-ML-Automation committed
Create copy if the operands of gather/scatter instructions overlap.
A gather has two operands: input and indices. If they point to the same instruction, create a copy for the indices.

A scatter has n inputs, 1 indices operand, and n updates (2n+1 operands in total). Overlap among the n inputs is allowed, and so is overlap among the n updates. A copy must be created if:
* the indices overlap with any input or update, or
* an update overlaps with any input.

The added copy will be removed by the subsequent memory-related passes (e.g., CopyInsertion) if it turns out to be redundant.

PiperOrigin-RevId: 715164959
1 parent b27c9d3 commit 1f8540e

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

xla/service/spmd/gather_scatter_handler.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,12 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
957957
auto gather = Cast<HloGatherInstruction>(hlo);
958958
const auto& dnums = gather->gather_dimension_numbers();
959959
auto operand = GetPartitionedHlo(gather->operand(0));
960-
auto indices = GetPartitionedHlo(gather->operand(1));
960+
auto raw_indices = GetPartitionedHlo(gather->operand(1));
961+
auto indices =
962+
(operand.hlo() == raw_indices.hlo())
963+
? MakeACopyAndReturnItsPartitionedHlo(raw_indices, builder())
964+
: raw_indices;
965+
961966
std::vector<int64_t> batch_dims;
962967
for (int64_t i = 0; i < gather->shape().rank(); ++i) {
963968
if (!absl::c_linear_search(dnums.offset_dims(), i)) {
@@ -1822,6 +1827,13 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
18221827
absl::c_transform(
18231828
scatter->scatter_updates(), std::back_inserter(updates),
18241829
[this](HloInstruction* hlo) { return GetPartitionedHlo(hlo); });
1830+
for (PartitionedHlo& update : updates) {
1831+
if (absl::c_any_of(operands, [&](const PartitionedHlo& operand) {
1832+
return update.hlo() == operand.hlo();
1833+
})) {
1834+
update = MakeACopyAndReturnItsPartitionedHlo(update, builder());
1835+
}
1836+
}
18251837
if (!absl::c_all_of(updates, [&](const PartitionedHlo& update) {
18261838
return update.sharding() == updates[0].sharding() &&
18271839
update.base_shape() == updates[0].base_shape();
@@ -1847,6 +1859,15 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
18471859
? scatter_reduction_root->shape().tuple_shapes_size()
18481860
: 1);
18491861
auto indices = GetPartitionedHlo(scatter->scatter_indices());
1862+
if (absl::c_any_of(operands,
1863+
[&](const PartitionedHlo& operand) {
1864+
return indices.hlo() == operand.hlo();
1865+
}) ||
1866+
absl::c_any_of(updates, [&](const PartitionedHlo& update) {
1867+
return indices.hlo() == update.hlo();
1868+
})) {
1869+
indices = MakeACopyAndReturnItsPartitionedHlo(indices, builder());
1870+
}
18501871
auto indices_sharding = indices.sharding();
18511872
// Reshard indices with -1 padding, which will have no effect on the result as
18521873
// guaranteed by the scatter semantics.

xla/service/spmd/spmd_partitioner_test.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14871,6 +14871,33 @@ ENTRY %main.21 {
1487114871
EXPECT_THAT(updates, op::Shape("bf16[4096,64]"));
1487214872
}
1487314873

14874+
// Regression test: a scatter whose three operands are all the same
// instruction (p0). After partitioning, the scatter is expected to use p0
// directly for its first operand and a copy of p0 for each of the other two.
TEST_P(SpmdPartitioningTest, ScatterAllOperandsAreSameInstruction) {
14875+
const char* const hlo_string = R"(
14876+
HloModule pjit
14877+
14878+
%s32_add {
14879+
a = s32[] parameter(0)
14880+
b = s32[] parameter(1)
14881+
ROOT result = s32[] add(a, b)
14882+
}
14883+
14884+
ENTRY %main.21 {
14885+
p0 = s32[8,64] parameter(0), sharding={devices=[4,1]<=[4]}
14886+
ROOT scatter = s32[8,64] scatter(p0, p0, p0), update_window_dims={},
14887+
input_batching_dims={0}, scatter_indices_batching_dims={0},
14888+
inserted_window_dims={1}, scatter_dims_to_operand_dims={1},
14889+
index_vector_dim=2, to_apply=s32_add, sharding={devices=[4,1]<=[4]}
14890+
})";
14891+
14892+
// Run the partitioner on the module above with 4 devices.
TF_ASSERT_OK_AND_ASSIGN(auto module,
14893+
PartitionComputation(hlo_string, /*num_devices=*/4));
14894+
14895+
// Matchers: the partitioned parameter with shape s32[2,64], and a copy of it
// with the same shape.
auto p0 = AllOf(op::Shape("s32[2,64]"), op::Parameter(0));
14896+
auto p0_copy = AllOf(op::Shape("s32[2,64]"), op::Copy(p0));
14897+
// The root must be a scatter over (p0, copy(p0), copy(p0)): only the first
// operand may reference p0 directly.
EXPECT_THAT(module->entry_computation()->root_instruction(),
14898+
AllOf(op::Shape("s32[2,64]"), op::Scatter(p0, p0_copy, p0_copy)));
14899+
}
14900+
1487414901
TEST_P(SpmdPartitioningTest, ComplexReshardUnmerge) {
1487514902
const char* const hlo_string = R"(
1487614903
HloModule Test

0 commit comments

Comments
 (0)