
Commit 644b188

ZixuanJiang authored and Google-ML-Automation committed
Refactor PartitionedHlo::ReshardWithAllToAll without behavior change.
PiperOrigin-RevId: 707525624
1 parent 2c70c13 commit 644b188

File tree

1 file changed: 82 additions, 116 deletions


xla/service/spmd/spmd_partitioner.cc

Lines changed: 82 additions & 116 deletions
@@ -1593,113 +1593,77 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
     const HloSharding& target,
     absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const {
+  if (target == sharding()) {
+    return *this;
+  }
+  VLOG(5) << "Source: " << sharding().ToString();
+  VLOG(5) << "Target: " << target.ToString();
   if (source_target_dims.empty()) {
-    if (target == sharding()) {
-      return *this;
-    }
     // If the device order is different in the target, fix the order with
     // ReshardWithCollectivePermute.
     return ReshardWithCollectivePermute(target);
   }
 
-  VLOG(5) << "Source: " << sharding().ToString();
-  VLOG(5) << "Target: " << target.ToString();
   // Swap one pair of dimensions.
-  int64_t source_dim = source_target_dims[0].first;
-  int64_t target_dim = source_target_dims[0].second;
+  const int64_t source_dim = source_target_dims[0].first;
+  const int64_t target_dim = source_target_dims[0].second;
+  VLOG(5) << "Source dim: " << source_dim;
+  VLOG(5) << "Target dim: " << target_dim;
+  CHECK_NE(source_dim, target_dim);
   const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
                              sharding().tile_assignment().dim(target_dim);
-
   VLOG(5) << "Group size: " << group_size;
-  auto temp_target_tile = [&] {
-    auto& original_tile_assignment = sharding().tile_assignment();
-    std::vector<int64_t> reshape_tile_dims(
-        original_tile_assignment.num_dimensions() + 2);
-    int64_t i = 0;
-    int64_t added_source_dim = -1;
-    int64_t added_target_dim = -1;
-    for (int64_t j = 0; j < original_tile_assignment.num_dimensions(); ++j) {
-      if (source_dim == j) {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j) / group_size;
-        reshape_tile_dims[++i] = group_size;
-        added_source_dim = i;
-      } else if (target_dim == j) {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j);
-        reshape_tile_dims[++i] = 1;
-        added_target_dim = i;
-      } else {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j);
-      }
-      ++i;
-    }
-    VLOG(5) << "Added target: " << added_target_dim;
-    VLOG(5) << "Added source: " << added_source_dim;
-    std::vector<int64_t> xpose_dims(reshape_tile_dims.size());
-    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
-    xpose_dims[added_source_dim] = added_target_dim;
-    xpose_dims[added_target_dim] = added_source_dim;
-    auto temp_target_tile =
-        hlo_sharding_util::TransposeSharding(
-            HloSharding::Tile(
-                original_tile_assignment.Reshape(reshape_tile_dims)),
-            xpose_dims)
-            .tile_assignment();
-    VLOG(5) << "Transposed target: " << temp_target_tile.ToString();
-    std::vector<int64_t> temp_target_tile_dims(
-        sharding().tile_assignment().dimensions().begin(),
-        sharding().tile_assignment().dimensions().end());
-    temp_target_tile_dims[source_dim] =
-        sharding().tile_assignment().dim(target_dim);
-    temp_target_tile_dims[target_dim] =
-        sharding().tile_assignment().dim(source_dim);
-    return temp_target_tile.Reshape(temp_target_tile_dims);
-  }();
+
+  std::vector<int64_t> reshape_tile_dims;
+  reshape_tile_dims.reserve(sharding().tile_assignment().num_dimensions() + 2);
+  int64_t added_source_dim;
+  int64_t added_target_dim;
+  for (int64_t j = 0; j < sharding().tile_assignment().num_dimensions(); ++j) {
+    if (source_dim == j) {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j) /
+                                  group_size);
+      reshape_tile_dims.push_back(group_size);
+      added_source_dim = reshape_tile_dims.size() - 1;
+    } else if (target_dim == j) {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j));
+      reshape_tile_dims.push_back(1);
+      added_target_dim = reshape_tile_dims.size() - 1;
+    } else {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j));
+    }
+  }
+  VLOG(5) << "Added target: " << added_target_dim;
+  VLOG(5) << "Added source: " << added_source_dim;
+  std::vector<int> xpose_dims(reshape_tile_dims.size());
+  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+  std::swap(xpose_dims[added_source_dim], xpose_dims[added_target_dim]);
+  std::vector<int64_t> temp_target_tile_dims(
+      sharding().tile_assignment().dimensions().begin(),
+      sharding().tile_assignment().dimensions().end());
+  std::swap(temp_target_tile_dims[source_dim],
+            temp_target_tile_dims[target_dim]);
+  auto temp_target_tile = sharding()
+                              .tile_assignment()
+                              .Reshape(reshape_tile_dims)
+                              .Transpose(xpose_dims)
+                              .Reshape(temp_target_tile_dims);
   auto temp_target = target.ReplicateOnLastTileDim()
                          ? HloSharding::PartialTile(temp_target_tile)
                          : HloSharding::Tile(temp_target_tile);
   VLOG(5) << "Temp target sharding: " << temp_target.ToString();
-  auto padded_shape = hlo_->shape();
-  auto padded_base_shape = base_shape_;
-  auto current_base_padded_shape = base_shape_;
-  padded_base_shape.set_dimensions(
-      target_dim, RoundUpTo(base_shape_.dimensions(target_dim),
-                            temp_target.tile_assignment().dim(target_dim)));
-  current_base_padded_shape.set_dimensions(
-      target_dim, hlo_->shape().dimensions(target_dim) *
-                      sharding().tile_assignment().dim(target_dim));
-
-  auto padded_source_base_shape = base_shape_;
-  auto current_source_base_padded_shape = base_shape_;
-  padded_source_base_shape.set_dimensions(
-      source_dim, RoundUpTo(base_shape_.dimensions(source_dim),
-                            temp_target.tile_assignment().dim(source_dim)));
-  current_source_base_padded_shape.set_dimensions(
-      source_dim, hlo_->shape().dimensions(source_dim) *
-                      sharding().tile_assignment().dim(source_dim));
-
-  VLOG(5) << "Target dim: " << target_dim;
-  VLOG(5) << "Source dim: " << source_dim;
-  VLOG(5) << "Original sharded shape: " << hlo_->shape();
-  VLOG(5) << "Base shape: " << base_shape_.ToString();
-  VLOG(5) << "Padded base shape: " << padded_base_shape.ToString();
-  VLOG(5) << "Current padded shape: " << current_base_padded_shape.ToString();
-  VLOG(5) << "Padded source base shape: "
-          << padded_source_base_shape.ToString();
-  VLOG(5) << "Current source padded shape: "
-          << current_source_base_padded_shape.ToString();
-  VLOG(5) << "Dimension padded target_dim: "
-          << hlo_->shape().dimensions(target_dim) *
-                 sharding().tile_assignment().dim(target_dim);
-  CHECK_GE(padded_base_shape.rank(), current_base_padded_shape.rank());
-  CHECK_LE(padded_source_base_shape.rank(),
-           current_source_base_padded_shape.rank());
 
   PaddingConfig pc;
   for (int64_t i = 0; i < hlo_->shape().rank(); ++i) {
     auto* pd = pc.add_dimensions();
     pd->set_edge_padding_low(0);
-    pd->set_edge_padding_high(padded_base_shape.dimensions(i) -
-                              current_base_padded_shape.dimensions(i));
+    if (i == target_dim) {
+      pd->set_edge_padding_high(
+          RoundUpTo(base_shape_.dimensions(i),
+                    temp_target.tile_assignment().dim(i)) -
+          hlo_->shape().dimensions(i) * sharding().tile_assignment().dim(i));
+    } else {
+      pd->set_edge_padding_high(0);
+    }
     pd->set_interior_padding(0);
   }
   PartitionedHlo p_hlo = *this;
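Note: the lambda that built temp_target_tile is replaced above by a direct Reshape / Transpose / Reshape on the tile assignment. The following standalone sketch (not XLA code; it assumes a hypothetical [4, 2] tile assignment with source_dim = 0, target_dim = 1, hence group_size = 2) emulates that sequence with plain vectors to show the device order it produces.

// Standalone sketch: reshape -> transpose -> reshape of a small tile
// assignment, using plain vectors instead of XLA's TileAssignment API.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Row-major transpose of `data` viewed with `dims`, under permutation `perm`
// (output axis i takes input axis perm[i]).
std::vector<int64_t> Transpose(const std::vector<int64_t>& data,
                               const std::vector<int64_t>& dims,
                               const std::vector<int>& perm) {
  std::vector<int64_t> out(data.size());
  std::vector<int64_t> out_dims(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) out_dims[i] = dims[perm[i]];
  // Row-major strides of the input view.
  std::vector<int64_t> strides(dims.size(), 1);
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * dims[i + 1];
  // Walk the output in row-major order and gather from the input.
  std::vector<int64_t> idx(dims.size(), 0);
  for (size_t o = 0; o < out.size(); ++o) {
    int64_t flat = 0;
    for (size_t i = 0; i < dims.size(); ++i) flat += idx[i] * strides[perm[i]];
    out[o] = data[flat];
    for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
      if (++idx[i] < out_dims[i]) break;
      idx[i] = 0;
    }
  }
  return out;
}

int main() {
  // Devices 0..7 laid out row-major as a hypothetical [4, 2] tile assignment.
  std::vector<int64_t> devices(8);
  std::iota(devices.begin(), devices.end(), 0);
  // Reshape [4, 2] -> [2, 2, 2, 1]: the source dim (size 4) is split into
  // [4 / group_size, group_size] and the target dim gets a trailing 1.
  // A row-major reshape does not move data, so only the dims change.
  std::vector<int64_t> reshape_dims = {2, 2, 2, 1};
  // Swap the two added dimensions (indices 1 and 3), as the new code does by
  // applying std::swap to an iota permutation.
  std::vector<int> perm = {0, 3, 2, 1};
  std::vector<int64_t> transposed = Transpose(devices, reshape_dims, perm);
  // The final reshape to the swapped dims [2, 4] is again a no-op on the
  // flat data, so the result can be printed directly as a 2 x 4 grid.
  for (int64_t r = 0; r < 2; ++r) {
    for (int64_t c = 0; c < 4; ++c) std::cout << transposed[r * 4 + c] << ' ';
    std::cout << '\n';
  }
  // Prints:
  // 0 2 1 3
  // 4 6 5 7
  return 0;
}

Per the commit description, the old lambda and the new chained form produce the same assignment; only the construction is simplified.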
@@ -1734,27 +1698,16 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
         groups[group_id].push_back(device);
       });
 
-  HloInstruction* result = nullptr;
-
-  // Split along the split dimension (target_dim) of the all-to-all
-  // output.
-  std::vector<int64_t> dimensions;
-  const int64_t rank = base_shape_.rank();
-  dimensions.reserve(rank + 1);
-  for (int64_t i = 0; i < rank; ++i) {
-    if (i == target_dim) {
-      dimensions.push_back(group_size);
-      dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
-    } else {
-      dimensions.push_back(padded_hlo->shape().dimensions(i));
-    }
-  }
-  VLOG(5) << "Target ata shape: "
-          << ShapeUtil::MakeShape(base_shape_.element_type(), dimensions)
-                 .ToString();
+  // Split along the split dimension (target_dim) of the all-to-all output.
+  std::vector<int64_t> target_ata_dims(padded_hlo->shape().dimensions().begin(),
+                                       padded_hlo->shape().dimensions().end());
+  target_ata_dims.insert(target_ata_dims.begin() + target_dim, group_size);
+  target_ata_dims[target_dim + 1] /= group_size;
   auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
+      ShapeUtil::MakeShape(base_shape_.element_type(), target_ata_dims),
       padded_hlo));
+  VLOG(5) << "Target ata shape: " << reshape->shape().ToString();
+
   // After the reshape, it is guaranteed to have at least 3 dimensions.
   auto all_to_all =
       state_.collective_ops_creator.create_cross_partition_all_to_all(
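Note: the per-dimension loop that built the all-to-all reshape shape is replaced above by an insert plus an in-place division. A minimal standalone sketch (not XLA code; the padded per-partition shape [8, 6], target_dim = 0, and group_size = 2 are made-up values) of what target_ata_dims ends up holding:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Stand-in for padded_hlo->shape().dimensions().
  std::vector<int64_t> padded_dims = {8, 6};
  const int64_t target_dim = 0;
  const int64_t group_size = 2;
  std::vector<int64_t> target_ata_dims(padded_dims.begin(), padded_dims.end());
  // Split the target dimension into [group_size, dim / group_size], matching
  // what the removed loop pushed for i == target_dim.
  target_ata_dims.insert(target_ata_dims.begin() + target_dim, group_size);
  target_ata_dims[target_dim + 1] /= group_size;
  for (int64_t d : target_ata_dims) std::cout << d << ' ';  // prints: 2 4 6
  std::cout << '\n';
  return 0;
}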
@@ -1783,27 +1736,40 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
   auto new_shape = ShapeInference::InferAllToAllShape(
                        padded_hlo->shape(), target_dim, source_dim, group_size)
                        .value();
-  result = state_.b->AddInstruction(
+  HloInstruction* result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
+  CHECK_EQ(result->shape().rank(), base_shape_.rank());
   result->set_sharding(temp_target);
+
+  auto padded_source_base_shape = base_shape_;
+  auto current_source_base_padded_shape = base_shape_;
+  padded_source_base_shape.set_dimensions(
+      source_dim, RoundUpTo(base_shape_.dimensions(source_dim),
+                            temp_target.tile_assignment().dim(source_dim)));
+  current_source_base_padded_shape.set_dimensions(
+      source_dim, hlo_->shape().dimensions(source_dim) *
+                      sharding().tile_assignment().dim(source_dim));
+
+  VLOG(5) << "Original sharded shape: " << hlo_->shape();
+  VLOG(5) << "Base shape: " << base_shape_.ToString();
+  VLOG(5) << "Padded source base shape: "
+          << padded_source_base_shape.ToString();
+  VLOG(5) << "Current source padded shape: "
+          << current_source_base_padded_shape.ToString();
+
   std::vector<int64_t> strides(result->shape().rank(), 1);
   std::vector<int64_t> starts(result->shape().rank(), 0);
-  std::vector<int64_t> limits(result->shape().rank());
-  for (int64_t i = 0; i < result->shape().rank(); ++i) {
-    limits[i] = padded_source_base_shape.dimensions(i);
-  }
   auto sliced_phlo = ReshardDataForSlicing(
-      strides, starts, limits,
+      strides, starts, padded_source_base_shape.dimensions(),
       PartitionedHlo(result, current_source_base_padded_shape, state_),
       temp_target, state_.b);
   CHECK(sliced_phlo.has_value());
   result = SliceDataFromWindowReshard(*sliced_phlo, strides, base_shape_,
                                       temp_target, state_.b);
   result->set_sharding(temp_target);
-  auto remaining_source_target_dims = source_target_dims;
-  remaining_source_target_dims.remove_prefix(1);
   return PartitionedHlo(result, base_shape_, state_)
-      .ReshardWithAllToAll(target, remaining_source_target_dims);
+      .ReshardWithAllToAll(
+          target, source_target_dims.last(source_target_dims.size() - 1));
 }
 
 namespace {
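Note: the tail recursion now passes the remaining dimension pairs via absl::Span::last instead of copying the span and calling remove_prefix(1). A small sketch (assumes Abseil is available; the pair values are made up) showing that the two forms select the same elements:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

#include "absl/types/span.h"

int main() {
  std::vector<std::pair<int64_t, int64_t>> dims = {{0, 1}, {2, 3}, {4, 5}};
  absl::Span<const std::pair<int64_t, int64_t>> source_target_dims = dims;

  // Old style: copy the span, then drop the first element in place.
  auto remaining = source_target_dims;
  remaining.remove_prefix(1);

  // New style: take the trailing size() - 1 elements directly.
  auto tail = source_target_dims.last(source_target_dims.size() - 1);

  std::cout << (remaining == tail) << '\n';  // prints: 1 (equal contents)
  return 0;
}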
