Skip to content

Commit 2e7cb97

Browse files
ZixuanJiang authored and Google-ML-Automation committed
Moving the logic of making a copy for rhs from PartitionDot to HandleDotHelper.
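`HandleDotHelper` is called once per dot operation, whereas `PartitionDot` can be called many times (it partitions recursively) while handling that same dot. The copy for rhs, needed when lhs and rhs are the same instruction, only has to be created once, so the check now lives in `HandleDotHelper`. PiperOrigin-RevId: 715189518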
1 parent 7d912e5 commit 2e7cb97

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

xla/service/spmd/dot_handler.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4114,7 +4114,7 @@ absl::StatusOr<HloInstruction*> PartitionDotRemovingOutputPartialReplication(
41144114
// in the operands and output, group the devices and recursively partition
41154115
// the in-group dot.
41164116
absl::StatusOr<HloInstruction*> PartitionDot(
4117-
const PartitionedHlo& lhs, const PartitionedHlo& raw_rhs,
4117+
const PartitionedHlo& lhs, const PartitionedHlo& rhs,
41184118
const Shape& output_base_shape, const HloSharding& output_sharding,
41194119
const DotConvolutionDimsInfo& dims_mapping, int64_t num_partitions,
41204120
absl::FunctionRef<absl::StatusOr<HloInstruction*>(
@@ -4127,12 +4127,6 @@ absl::StatusOr<HloInstruction*> PartitionDot(
41274127
std::vector<SpmdPartitioningVisitor::WindowedDotGeneralLoop>*
41284128
windowed_dot_general_loops,
41294129
SpmdPartitioningVisitor* visitor) {
4130-
// If lhs' hlo and rhs' hlo are identical, make a copy for rhs.
4131-
const PartitionedHlo& rhs =
4132-
(lhs.hlo() == raw_rhs.hlo())
4133-
? MakeACopyAndReturnItsPartitionedHlo(raw_rhs, b)
4134-
: raw_rhs;
4135-
41364130
// Recursively partition on different types of dimensions.
41374131

41384132
// Case 0: Try partition the purely spatially-partitioned convolution with
@@ -4306,8 +4300,14 @@ absl::Status SpmdPartitioningVisitor::HandleDotHelper(
43064300
if (hlo->sharding().HasUniqueDevice()) {
43074301
return DefaultAction(hlo);
43084302
}
4309-
auto& lhs = GetPartitionedHlo(hlo->operand(0));
4310-
auto& rhs = GetPartitionedHlo(hlo->operand(1));
4303+
PartitionedHlo& lhs = GetPartitionedHlo(hlo->operand(0));
4304+
PartitionedHlo& raw_rhs = GetPartitionedHlo(hlo->operand(1));
4305+
// If lhs and rhs are the same instruction, make a copy for rhs.
4306+
const PartitionedHlo& rhs =
4307+
(lhs.hlo() == raw_rhs.hlo())
4308+
? MakeACopyAndReturnItsPartitionedHlo(raw_rhs, builder())
4309+
: raw_rhs;
4310+
43114311
Window conv_window;
43124312
if (hlo->opcode() == HloOpcode::kConvolution) {
43134313
conv_window = hlo->window();

0 commit comments

Comments
 (0)