Skip to content

Commit 7f0053a

Browse files
authored
More spmd fixes (#1843)
* More spmd fixes * fix
1 parent 9b15b0f commit 7f0053a

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

patches/xla_spmd2.patch

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
diff --git a/xla/service/spmd/spmd_partitioner.cc b/xla/service/spmd/spmd_partitioner.cc
2+
index 54fdd9a0fe..08c5f900a5 100644
3+
--- a/xla/service/spmd/spmd_partitioner.cc
4+
+++ b/xla/service/spmd/spmd_partitioner.cc
5+
@@ -570,6 +570,17 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(
6+
return PartitionedHlo(partitioned, base_shape_, state_);
7+
}
8+
9+
+ if (state_.module->config().debug_options().xla_enable_enzyme_comms_opt()) {
10+
+ if (hlo_->opcode() == HloOpcode::kBroadcast && hlo_->operand(0)->shape().dimensions().size() == 0 &&
11+
+ hlo_->operand(0)->IsConstant()) {
12+
+ HloInstruction* new_broadcast =
13+
+ state_.b->AddInstruction(HloInstruction::CreateBroadcast(
14+
+ hlo_->shape(), hlo_->mutable_operand(0), {}));
15+
+ new_broadcast->set_sharding(target);
16+
+ return PartitionedHlo(new_broadcast, base_shape_, state_);
17+
+ }
18+
+ }
19+
+
20+
if (CanReshardWithCollectivePermute(sharding(), target)) {
21+
return ReshardWithCollectivePermute(target);
22+
}

third_party/xla/workspace.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ def repo(extra_patches = [], override_commit = ""):
1717
strip_prefix = "openxla-xla-{commit}".format(commit = commit[:7]),
1818
urls = ["https://api.github.com/repos/openxla/xla/tarball/{commit}".format(commit = commit)],
1919
patch_cmds = XLA_PATCHES + extra_patches,
20-
patches = ["//:patches/xla.patch", "//:patches/xla_spmd.patch"],
20+
patches = ["//:patches/xla.patch", "//:patches/xla_spmd.patch", "//:patches/xla_spmd2.patch"],
2121
patch_args = ["-p1"],
2222
)

0 commit comments

Comments
 (0)