@@ -1593,113 +1593,77 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
 PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
     const HloSharding& target,
     absl::Span<const std::pair<int64_t, int64_t>> source_target_dims) const {
+  if (target == sharding()) {
+    return *this;
+  }
+  VLOG(5) << "Source: " << sharding().ToString();
+  VLOG(5) << "Target: " << target.ToString();
   if (source_target_dims.empty()) {
-    if (target == sharding()) {
-      return *this;
-    }
     // If the device order is different in the target, fix the order with
     // ReshardWithCollectivePermute.
     return ReshardWithCollectivePermute(target);
   }
 
-  VLOG(5) << "Source: " << sharding().ToString();
-  VLOG(5) << "Target: " << target.ToString();
   // Swap one pair of dimensions.
-  int64_t source_dim = source_target_dims[0].first;
-  int64_t target_dim = source_target_dims[0].second;
+  const int64_t source_dim = source_target_dims[0].first;
+  const int64_t target_dim = source_target_dims[0].second;
+  VLOG(5) << "Source dim: " << source_dim;
+  VLOG(5) << "Target dim: " << target_dim;
+  CHECK_NE(source_dim, target_dim);
   const int64_t group_size = sharding().tile_assignment().dim(source_dim) /
                              sharding().tile_assignment().dim(target_dim);
-
   VLOG(5) << "Group size: " << group_size;
-  auto temp_target_tile = [&] {
-    auto& original_tile_assignment = sharding().tile_assignment();
-    std::vector<int64_t> reshape_tile_dims(
-        original_tile_assignment.num_dimensions() + 2);
-    int64_t i = 0;
-    int64_t added_source_dim = -1;
-    int64_t added_target_dim = -1;
-    for (int64_t j = 0; j < original_tile_assignment.num_dimensions(); ++j) {
-      if (source_dim == j) {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j) / group_size;
-        reshape_tile_dims[++i] = group_size;
-        added_source_dim = i;
-      } else if (target_dim == j) {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j);
-        reshape_tile_dims[++i] = 1;
-        added_target_dim = i;
-      } else {
-        reshape_tile_dims[i] = original_tile_assignment.dim(j);
-      }
-      ++i;
-    }
-    VLOG(5) << "Added target: " << added_target_dim;
-    VLOG(5) << "Added source: " << added_source_dim;
-    std::vector<int64_t> xpose_dims(reshape_tile_dims.size());
-    std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
-    xpose_dims[added_source_dim] = added_target_dim;
-    xpose_dims[added_target_dim] = added_source_dim;
-    auto temp_target_tile =
-        hlo_sharding_util::TransposeSharding(
-            HloSharding::Tile(
-                original_tile_assignment.Reshape(reshape_tile_dims)),
-            xpose_dims)
-            .tile_assignment();
-    VLOG(5) << "Transposed target: " << temp_target_tile.ToString();
-    std::vector<int64_t> temp_target_tile_dims(
-        sharding().tile_assignment().dimensions().begin(),
-        sharding().tile_assignment().dimensions().end());
-    temp_target_tile_dims[source_dim] =
-        sharding().tile_assignment().dim(target_dim);
-    temp_target_tile_dims[target_dim] =
-        sharding().tile_assignment().dim(source_dim);
-    return temp_target_tile.Reshape(temp_target_tile_dims);
-  }();
+
+  std::vector<int64_t> reshape_tile_dims;
+  reshape_tile_dims.reserve(sharding().tile_assignment().num_dimensions() + 2);
+  int64_t added_source_dim;
+  int64_t added_target_dim;
+  for (int64_t j = 0; j < sharding().tile_assignment().num_dimensions(); ++j) {
+    if (source_dim == j) {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j) /
+                                  group_size);
+      reshape_tile_dims.push_back(group_size);
+      added_source_dim = reshape_tile_dims.size() - 1;
+    } else if (target_dim == j) {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j));
+      reshape_tile_dims.push_back(1);
+      added_target_dim = reshape_tile_dims.size() - 1;
+    } else {
+      reshape_tile_dims.push_back(sharding().tile_assignment().dim(j));
+    }
+  }
+  VLOG(5) << "Added target: " << added_target_dim;
+  VLOG(5) << "Added source: " << added_source_dim;
+  std::vector<int> xpose_dims(reshape_tile_dims.size());
+  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+  std::swap(xpose_dims[added_source_dim], xpose_dims[added_target_dim]);
+  std::vector<int64_t> temp_target_tile_dims(
+      sharding().tile_assignment().dimensions().begin(),
+      sharding().tile_assignment().dimensions().end());
+  std::swap(temp_target_tile_dims[source_dim],
+            temp_target_tile_dims[target_dim]);
+  auto temp_target_tile = sharding()
+                              .tile_assignment()
+                              .Reshape(reshape_tile_dims)
+                              .Transpose(xpose_dims)
+                              .Reshape(temp_target_tile_dims);
   auto temp_target = target.ReplicateOnLastTileDim()
                          ? HloSharding::PartialTile(temp_target_tile)
                          : HloSharding::Tile(temp_target_tile);
   VLOG(5) << "Temp target sharding: " << temp_target.ToString();
-  auto padded_shape = hlo_->shape();
-  auto padded_base_shape = base_shape_;
-  auto current_base_padded_shape = base_shape_;
-  padded_base_shape.set_dimensions(
-      target_dim, RoundUpTo(base_shape_.dimensions(target_dim),
-                            temp_target.tile_assignment().dim(target_dim)));
-  current_base_padded_shape.set_dimensions(
-      target_dim, hlo_->shape().dimensions(target_dim) *
-                      sharding().tile_assignment().dim(target_dim));
-
-  auto padded_source_base_shape = base_shape_;
-  auto current_source_base_padded_shape = base_shape_;
-  padded_source_base_shape.set_dimensions(
-      source_dim, RoundUpTo(base_shape_.dimensions(source_dim),
-                            temp_target.tile_assignment().dim(source_dim)));
-  current_source_base_padded_shape.set_dimensions(
-      source_dim, hlo_->shape().dimensions(source_dim) *
-                      sharding().tile_assignment().dim(source_dim));
-
-  VLOG(5) << "Target dim: " << target_dim;
-  VLOG(5) << "Source dim: " << source_dim;
-  VLOG(5) << "Original sharded shape: " << hlo_->shape();
-  VLOG(5) << "Base shape: " << base_shape_.ToString();
-  VLOG(5) << "Padded base shape: " << padded_base_shape.ToString();
-  VLOG(5) << "Current padded shape: " << current_base_padded_shape.ToString();
-  VLOG(5) << "Padded source base shape: "
-          << padded_source_base_shape.ToString();
-  VLOG(5) << "Current source padded shape: "
-          << current_source_base_padded_shape.ToString();
-  VLOG(5) << "Dimension padded target_dim: "
-          << hlo_->shape().dimensions(target_dim) *
-                 sharding().tile_assignment().dim(target_dim);
-  CHECK_GE(padded_base_shape.rank(), current_base_padded_shape.rank());
-  CHECK_LE(padded_source_base_shape.rank(),
-           current_source_base_padded_shape.rank());
 
   PaddingConfig pc;
   for (int64_t i = 0; i < hlo_->shape().rank(); ++i) {
     auto* pd = pc.add_dimensions();
     pd->set_edge_padding_low(0);
-    pd->set_edge_padding_high(padded_base_shape.dimensions(i) -
-                              current_base_padded_shape.dimensions(i));
+    if (i == target_dim) {
+      pd->set_edge_padding_high(
+          RoundUpTo(base_shape_.dimensions(i),
+                    temp_target.tile_assignment().dim(i)) -
+          hlo_->shape().dimensions(i) * sharding().tile_assignment().dim(i));
+    } else {
+      pd->set_edge_padding_high(0);
+    }
     pd->set_interior_padding(0);
   }
   PartitionedHlo p_hlo = *this;
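
The lambda that previously built temp_target_tile is now straight-line code: reshape the tile assignment to split source_dim by group_size and add a trivial axis after target_dim, transpose the two freshly added axes, then reshape to the final shape with the source/target tile counts swapped. Below is a standalone sketch of that index shuffle on a flat device array; the Transpose helper and the [4, 2] tile assignment are hypothetical illustrations, not the XLA TileAssignment API.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Row-major transpose of a flat array viewed with shape `dims`, permuting
// axes by `perm` (out_dims[i] == dims[perm[i]]).
std::vector<int64_t> Transpose(const std::vector<int64_t>& flat,
                               const std::vector<int64_t>& dims,
                               const std::vector<int>& perm) {
  const int64_t rank = dims.size();
  std::vector<int64_t> out_dims(rank);
  for (int64_t i = 0; i < rank; ++i) out_dims[i] = dims[perm[i]];
  std::vector<int64_t> out(flat.size());
  std::vector<int64_t> idx(rank, 0);
  for (int64_t linear = 0; linear < static_cast<int64_t>(flat.size());
       ++linear) {
    // Decompose `linear` into a multi-index over `dims`.
    int64_t rem = linear;
    for (int64_t i = rank - 1; i >= 0; --i) {
      idx[i] = rem % dims[i];
      rem /= dims[i];
    }
    // Recompose the permuted multi-index into the output's linear index.
    int64_t out_linear = 0;
    for (int64_t i = 0; i < rank; ++i) {
      out_linear = out_linear * out_dims[i] + idx[perm[i]];
    }
    out[out_linear] = flat[linear];
  }
  return out;
}

int main() {
  // Hypothetical [4, 2] tile assignment over 8 devices; source_dim = 0,
  // target_dim = 1, so group_size = 4 / 2 = 2.
  std::vector<int64_t> devices(8);
  std::iota(devices.begin(), devices.end(), 0);
  // Reshape [4, 2] -> [2, 2, 2, 1]: split source_dim by group_size and add a
  // trivial axis after target_dim; then swap the two added axes (1 and 3).
  std::vector<int64_t> swapped = Transpose(devices, {2, 2, 2, 1}, {0, 3, 2, 1});
  // Viewed as [2, 4], this prints 0 2 1 3 4 6 5 7.
  for (int64_t d : swapped) std::cout << d << ' ';
  std::cout << '\n';
}

With the identity [4, 2] assignment, the result viewed as [2, 4] is {{0, 2, 1, 3}, {4, 6, 5, 7}}: each group of group_size devices that used to split source_dim now splits target_dim, which is exactly the tile-count swap the temp target needs.
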
@@ -1734,27 +1698,16 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
         groups[group_id].push_back(device);
       });
 
-  HloInstruction* result = nullptr;
-
-  // Split along the split dimension (target_dim) of the all-to-all
-  // output.
-  std::vector<int64_t> dimensions;
-  const int64_t rank = base_shape_.rank();
-  dimensions.reserve(rank + 1);
-  for (int64_t i = 0; i < rank; ++i) {
-    if (i == target_dim) {
-      dimensions.push_back(group_size);
-      dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
-    } else {
-      dimensions.push_back(padded_hlo->shape().dimensions(i));
-    }
-  }
-  VLOG(5) << "Target ata shape: "
-          << ShapeUtil::MakeShape(base_shape_.element_type(), dimensions)
-                 .ToString();
+  // Split along the split dimension (target_dim) of the all-to-all output.
+  std::vector<int64_t> target_ata_dims(padded_hlo->shape().dimensions().begin(),
+                                       padded_hlo->shape().dimensions().end());
+  target_ata_dims.insert(target_ata_dims.begin() + target_dim, group_size);
+  target_ata_dims[target_dim + 1] /= group_size;
   auto reshape = state_.b->AddInstruction(HloInstruction::CreateReshape(
-      ShapeUtil::MakeShape(base_shape_.element_type(), dimensions),
+      ShapeUtil::MakeShape(base_shape_.element_type(), target_ata_dims),
       padded_hlo));
+  VLOG(5) << "Target ata shape: " << reshape->shape().ToString();
+
   // After the reshape, it is guaranteed to have at least 3 dimensions.
   auto all_to_all =
       state_.collective_ops_creator.create_cross_partition_all_to_all(
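
The new target_ata_dims computation replaces the old element-by-element loop with a single insert and divide. A minimal sketch with assumed numbers (the [12, 4] shard shape, target_dim, and group_size values are made up for illustration, not taken from the change):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Assumed local padded shard shape [12, 4] with target_dim = 0 and
  // group_size = 2; splitting target_dim into [group_size, rest] yields
  // the all-to-all input shape [2, 6, 4].
  std::vector<int64_t> dims = {12, 4};
  const int64_t target_dim = 0;
  const int64_t group_size = 2;
  dims.insert(dims.begin() + target_dim, group_size);
  dims[target_dim + 1] /= group_size;
  assert((dims == std::vector<int64_t>{2, 6, 4}));
}
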
@@ -1783,27 +1736,40 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
   auto new_shape = ShapeInference::InferAllToAllShape(
                        padded_hlo->shape(), target_dim, source_dim, group_size)
                        .value();
-  result = state_.b->AddInstruction(
+  HloInstruction* result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
+  CHECK_EQ(result->shape().rank(), base_shape_.rank());
   result->set_sharding(temp_target);
+
+  auto padded_source_base_shape = base_shape_;
+  auto current_source_base_padded_shape = base_shape_;
+  padded_source_base_shape.set_dimensions(
+      source_dim, RoundUpTo(base_shape_.dimensions(source_dim),
+                            temp_target.tile_assignment().dim(source_dim)));
+  current_source_base_padded_shape.set_dimensions(
+      source_dim, hlo_->shape().dimensions(source_dim) *
+                      sharding().tile_assignment().dim(source_dim));
+
+  VLOG(5) << "Original sharded shape: " << hlo_->shape();
+  VLOG(5) << "Base shape: " << base_shape_.ToString();
+  VLOG(5) << "Padded source base shape: "
+          << padded_source_base_shape.ToString();
+  VLOG(5) << "Current source padded shape: "
+          << current_source_base_padded_shape.ToString();
+
   std::vector<int64_t> strides(result->shape().rank(), 1);
   std::vector<int64_t> starts(result->shape().rank(), 0);
-  std::vector<int64_t> limits(result->shape().rank());
-  for (int64_t i = 0; i < result->shape().rank(); ++i) {
-    limits[i] = padded_source_base_shape.dimensions(i);
-  }
   auto sliced_phlo = ReshardDataForSlicing(
-      strides, starts, limits,
+      strides, starts, padded_source_base_shape.dimensions(),
       PartitionedHlo(result, current_source_base_padded_shape, state_),
       temp_target, state_.b);
   CHECK(sliced_phlo.has_value());
   result = SliceDataFromWindowReshard(*sliced_phlo, strides, base_shape_,
                                       temp_target, state_.b);
   result->set_sharding(temp_target);
-  auto remaining_source_target_dims = source_target_dims;
-  remaining_source_target_dims.remove_prefix(1);
   return PartitionedHlo(result, base_shape_, state_)
-      .ReshardWithAllToAll(target, remaining_source_target_dims);
+      .ReshardWithAllToAll(
+          target, source_target_dims.last(source_target_dims.size() - 1));
 }
 
 namespace {
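
The tail call in the last hunk now trims the already-handled dimension pair with absl::Span::last instead of copying the span and calling remove_prefix(1). A small usage sketch of that idiom, with placeholder pair values:

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

#include "absl/types/span.h"

int main() {
  std::vector<std::pair<int64_t, int64_t>> pairs = {{0, 1}, {2, 3}};
  absl::Span<const std::pair<int64_t, int64_t>> dims = pairs;
  // dims.last(dims.size() - 1) is a subspan that drops the first pair,
  // which the current call just handled; the old code achieved the same
  // with a mutable span copy plus remove_prefix(1).
  auto rest = dims.last(dims.size() - 1);
  std::cout << rest.size() << '\n';  // prints 1
}
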