@@ -11,6 +11,7 @@
 from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from functools import partial
+from itertools import zip_longest
 from typing import (
     Any,
     cast,
@@ -56,12 +57,12 @@
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
-    get_largest_dims_from_sharding_plan_updates,
-    move_sharded_tensors_to_cpu,
-    shards_all_to_all,
+    CommP2PMetadata,
+    CommStrategy,
+    prepare_comm_ops,
+    transfer_data,
     update_module_sharding_plan,
-    update_optimizer_state_post_resharding,
-    update_state_post_resharding,
+    update_state_dictionaries,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
@@ -1378,24 +1379,10 @@ def _init_mean_pooling_callback(
             device=self._device,
         )
 
-    def _purge_lookups(self) -> None:
-        # Purge old lookups
-        for lookup in self._lookups:
-            # Call purge method if available (for TBE modules)
-            if hasattr(lookup, "purge") and callable(lookup.purge):
-                # Pyre-ignore
-                lookup.purge()
-
-            # For DDP modules, get the underlying module
-            while isinstance(lookup, DistributedDataParallel):
-                lookup = lookup.module
-            if hasattr(lookup, "purge") and callable(lookup.purge):
-                lookup.purge()
-
-        # Clear the lookups list
+    def _softcopy_lookups(self) -> List[nn.Module]:
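+        # Hand the existing lookup modules back to the caller instead of purging them;
+        # update_shards keeps them alive until the resharded state is loaded, then drops them.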
+        old_modules: List[nn.Module] = [lookup for lookup in self._lookups]
         self._lookups.clear()
-        # Force garbage collection to free memory
-        torch.cuda.empty_cache()
+        return old_modules
 
     def _create_lookups(
         self,
@@ -1727,77 +1714,58 @@ def update_shards(
         device: Optional[torch.device],
     ) -> None:
         """
-        This is the main API used in sharder.reshard, currently only support redistribution
-        of existing shards (across different ranks, ideally from hot ranks to cold ranks)
-        Update shards for this module based on the changed_sharding_params. This will:
-        1. Move current lookup tensors to CPU
-        2. Purge lookups
-        3. Call shards_all_2_all containing collective to redistribute tensors
-        4. Update state_dict and other attributes to reflect new placements and shards
-        5. Create new lookups, and load in updated state_dict
+        Updates the sharded embedding module in place based on the changed_sharding_params,
+        which contains the new ParameterSharding with different shard placements.
+
+        This method handles resharding of embedding tables, optimizer state transfer,
+        and updates the internal lookup and distribution modules to reflect the new sharding.
 
         Args:
             changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
                 table names to their new parameter sharding configs. This should only
                 contain shards/table names that need to be moved.
-            env (ShardingEnv): The sharding environment for the module.
+            env (ShardingEnv): The sharding environment.
             device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            None
+        Raises:
+            RuntimeError: If DTensor output is enabled, as resharding is not yet supported for DTensor.
         """
         if env.output_dtensor:
             raise RuntimeError("We do not yet support DTensor for resharding yet")
             return
 
         current_state = self.state_dict()
-        current_state = move_sharded_tensors_to_cpu(current_state)
-        # TODO: improve, checking one would be enough
+
+        # Check if local optimizer state exists and is non-empty for all optimizers.
         has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )
 
-        # communicate optimizer state across all ranks, because if one rank owns all tables
-        # and other ranks does not own any table, and later transfer the weights to empty rank
-        # creates inconsistent state, because initally empty rank does not have optimizer state
-        # hence, incorrectly computes the tensor splits
-
+        # Communicate optimizer state availability across all ranks: a rank that owns no
+        # tables has no local optimizer state, so checking only locally would be inconsistent.
         has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
 
-        # TODO: make sure this is clearing all lookups
-        self._purge_lookups()
+        # Save old lookup modules for cleanup.
+        old_lookups: List[nn.Module] = self._softcopy_lookups()
 
-        # Get max dim size to enable padding for all_to_all
-        max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
-            changed_sharding_params
-        )
+        # Save old optimizer state if present.
         old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
-        if old_optimizer_state is not None:
-            move_sharded_tensors_to_cpu(old_optimizer_state)
 
-        local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all(
-            module=self,
-            state_dict=current_state,
-            device=device,  # pyre-ignore
-            changed_sharding_params=changed_sharding_params,
-            env=env,
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
-            max_dim_1=max_dim_1,
-            optimizer_state=old_optimizer_state,
-            has_optimizer=has_optimizer,
-        )
+        assert hasattr(self, "module_sharding_plan")
+        current_module_sharding_plan = copy.deepcopy(self.module_sharding_plan)
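+        # The deep copy is taken before the plan is mutated below; it preserves the
+        # pre-resharding placements that are later passed to prepare_comm_ops.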
 
-        for name, param in changed_sharding_params.items():
-            self.module_sharding_plan[name] = param
-            # TODO: Support detecting old sharding type when sharding type is changing
-            for sharding_info in self.sharding_type_to_sharding_infos[
-                param.sharding_type
-            ]:
-                if sharding_info.embedding_config.name == name:
-                    sharding_info.param_sharding = param
+        # Update the module sharding plan with the changed sharding parameters.
+        update_module_sharding_plan(
+            self, changed_sharding_params, self.sharding_type_to_sharding_infos
+        )
 
         self._sharding_types: List[str] = list(
             self.sharding_type_to_sharding_infos.keys()
         )
         # TODO: Optimize to update only the changed embedding shardings
+
+        # Recreate embedding sharding modules based on the new sharding infos.
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -1816,7 +1784,7 @@ def update_shards(
             for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
-        # Reset input dists
+        # Reset input distribution and feature ordering.
         self._has_uninitialized_input_dist = True
         self._input_dists: List[nn.Module] = []
         self._features_order: List[int] = []
@@ -1825,15 +1793,16 @@ def update_shards(
         self._create_lookups()
         self._update_output_dist()
 
+        # Re-initialize torch state if in a distributed environment.
         if env.process_group and dist.get_backend(env.process_group) != "fake":
             self._initialize_torch_state(skip_registering=True)
 
-        # update optimizer
+        # Update optimizer to reflect new parameters.
         optims = []
         for lookup in self._lookups:
             for _, tbe_module in lookup.named_modules():
                 if isinstance(tbe_module, FusedOptimizerModule):
-                    # modify param keys to match EmbeddingBagCollection
+                    # Modify param keys to match EmbeddingBagCollection
                     params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
                     for (
                         param_key,
@@ -1845,31 +1814,82 @@ def update_shards(
                     optims.append(("", tbe_module.fused_optimizer))
 
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
+        new_state = self.state_dict()
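+        # new_state reflects the freshly created lookups under the updated plan and
+        # serves as the destination for the transferred shards.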
 
-        if has_optimizer:
-            optimizer_state = update_optimizer_state_post_resharding(
-                old_opt_state=old_optimizer_state,  # pyre-ignore
-                new_opt_state=self._optim.state_dict(),
-                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_output_tensor_cpu,
-                max_dim_0=max_dim_0,
+        optimizer_state: Dict[str, Dict[str, Dict[str, Any]]] = self._optim.state_dict()
+
+        # Prepare and execute communication operations for state transfer.
+        shard_keys = list(changed_sharding_params.keys())
+        comms_op: Dict[CommStrategy, List[CommP2PMetadata]] = {}
+        reqs: List[Tuple[dist.Work, CommP2PMetadata]] = []
+        # Pipeline to overlap communication and computation: move the current table's
+        # shards while preparing the next table's shards for communication.
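+        # Each iteration issues the P2P ops already prepared for the current table,
+        # builds the ops for the next table while those transfers are in flight, and
+        # then updates the state/optimizer dictionaries from the completed transfers.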
+        for i, (shard_name, nxt_shard_name) in enumerate(
+            zip_longest(shard_keys, shard_keys[1:])
+        ):
+            if i == 0:
+                # Prepare communication P2P operations
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+
+            if comms_op:
+                # call underlying batch_isend_irecv primitives
+                reqs = transfer_data(comms_op=comms_op)
+
+            if nxt_shard_name:
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=nxt_shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+            else:
+                break
+            # Update state and optimizer states
+            update_state_dictionaries(
+                reqs=reqs,
+                old_optimizer_state=old_optimizer_state,
+                new_optimizer_state=optimizer_state,
+                old_state=current_state,
+                new_state=new_state,
+                changed_sharding_params=changed_sharding_params,
                 extend_shard_name=self.extend_shard_name,
             )
-            self._optim.load_state_dict(optimizer_state)
 
-        new_state = self.state_dict()
-        current_state = update_state_post_resharding(
+        update_state_dictionaries(
+            reqs=reqs,
+            old_optimizer_state=old_optimizer_state,
+            new_optimizer_state=optimizer_state,
             old_state=current_state,
             new_state=new_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor_cpu,
+            changed_sharding_params=changed_sharding_params,
             extend_shard_name=self.extend_shard_name,
-            has_optimizer=has_optimizer,
+            update_local=True,
         )
 
-        self.load_state_dict(current_state)
-
-        update_module_sharding_plan(self, changed_sharding_params)
+        # Clean up old lookup modules.
+        for lookup in old_lookups:
+            del lookup
+        old_lookups.clear()
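+        # assign=True swaps the tensors from new_state into the module instead of
+        # copying them element-wise into the freshly initialized parameters.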
+        self.load_state_dict(new_state, assign=True)
+        if has_optimizer:
+            self._optim.load_state_dict(optimizer_state)
         return
 
     def create_rocksdb_hard_link_snapshot(self) -> None: