diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 95913a65a..cbcd0d78b 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -456,23 +456,16 @@ def _get_sharded_local_buckets_for_zero_collision( for table in embedding_tables: total_num_buckets = none_throws(table.total_num_buckets) + assert ( + total_num_buckets % world_size == 0 + ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" assert ( table.total_num_buckets and table.num_embeddings % table.total_num_buckets == 0 ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" - extra_local_buckets = int(local_rank < (total_num_buckets % world_size)) - extra_bucket_padding = ( - (total_num_buckets % world_size) - if local_rank >= (total_num_buckets % world_size) - else 0 - ) - bucket_offset_start = ( - total_num_buckets // world_size + extra_local_buckets - ) * local_rank + extra_bucket_padding + bucket_offset_start = total_num_buckets // world_size * local_rank bucket_offset_end = min( - total_num_buckets, - (total_num_buckets // world_size + extra_local_buckets) * (local_rank + 1) - + extra_bucket_padding, + total_num_buckets, total_num_buckets // world_size * (local_rank + 1) ) bucket_size = ( table.num_embeddings + total_num_buckets - 1 diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py index e444f59c8..21776d697 100644 --- a/torchrec/distributed/embedding_kernel.py +++ b/torchrec/distributed/embedding_kernel.py @@ -99,13 +99,9 @@ def create_virtual_table_global_metadata( # Otherwise it will only set correct size on current rank and # virtual PMT will trigger recalc for the correct global size/offset. # NOTE this currently only works for row-wise sharding - my_rank_shard_size = metadata.shards_metadata[my_rank].shard_sizes[0] for rank, shard_metadata in enumerate(metadata.shards_metadata): if use_param_size_as_rows: # respect the param size and treat it as rows - # The param size only has the information for my_rank. In order to - # correctly calculate the size for other ranks, we need to use the current - # rank's shard size compared to the shard size of my_rank. - curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size # pyre-ignore[16] + curr_rank_rows = param.size()[0] # pyre-ignore[16] else: curr_rank_rows = ( weight_count_per_rank[rank] if weight_count_per_rank is not None else 1 diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 202be6b71..1e3abbfcb 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -38,10 +38,6 @@ ShardingType, ) from torchrec.modules.embedding_configs import DataType -from torchrec.modules.embedding_modules import ( - EmbeddingBagCollection, - EmbeddingCollection, -) from torchrec.modules.embedding_tower import EmbeddingTower, EmbeddingTowerCollection @@ -182,7 +178,7 @@ def enumerate( # skip for other device groups if device_group and device_group != self._compute_device: continue - num_buckets = self._get_num_buckets(name, child_module) + sharding_options_per_table: List[ShardingOption] = [] for sharding_type in self._filter_sharding_types( @@ -204,7 +200,6 @@ def enumerate( sharding_type=sharding_type, col_wise_shard_dim=col_wise_shard_dim, device_memory_sizes=self._device_memory_sizes, - num_buckets=num_buckets, ) except ZeroDivisionError as e: # Re-raise with additional context about the table and module @@ -269,33 +264,6 @@ def enumerate( self._last_stored_search_space = copy.deepcopy(sharding_options) return sharding_options - def _get_num_buckets(self, parameter: str, module: nn.Module) -> Optional[int]: - """ - Get the number of buckets for each embedding table. - - Args: - parameter (str): name of the embedding table. - module (nn.Module): module to be sharded. - - Returns: - Optional[int]: Number of buckets for the table, or None if module is not EmbeddingBagCollection or table not found. - """ - # If module is not of type EmbeddingBagCollection, return None - if isinstance(module, EmbeddingBagCollection): - embedding_configs = module.embedding_bag_configs() - elif isinstance(module, EmbeddingCollection): - embedding_configs = module.embedding_configs() - else: - return None - - # Find the embedding config for the table with the same name as parameter input - for config in embedding_configs: - if config.name == parameter and config.use_virtual_table: - return config.total_num_buckets - - # If table with matching name not found, return None - return None - @property def last_stored_search_space(self) -> Optional[List[ShardingOption]]: # NOTE: This is the last search space stored by enumerate(...), do not use diff --git a/torchrec/distributed/planner/tests/test_enumerators.py b/torchrec/distributed/planner/tests/test_enumerators.py index 39a39d9f0..5adead69a 100644 --- a/torchrec/distributed/planner/tests/test_enumerators.py +++ b/torchrec/distributed/planner/tests/test_enumerators.py @@ -18,10 +18,7 @@ EmbeddingTowerSharder, ) from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.embeddingbag import ( - EmbeddingBagCollection, - EmbeddingBagCollectionSharder, -) +from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, ) @@ -48,13 +45,6 @@ [[17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [17, 80], [11, 80]], ] -EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS = [ - [[20, 20], [20, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20], [10, 20]], - [[22, 40], [22, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40], [11, 40]], - [[24, 60], [24, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60], [12, 60]], - [[26, 80], [26, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80], [13, 80]], -] - EXPECTED_RW_SHARD_OFFSETS = [ [[0, 0], [13, 0], [26, 0], [39, 0], [52, 0], [65, 0], [78, 0], [91, 0]], [[0, 0], [14, 0], [28, 0], [42, 0], [56, 0], [70, 0], [84, 0], [98, 0]], @@ -62,13 +52,6 @@ [[0, 0], [17, 0], [34, 0], [51, 0], [68, 0], [85, 0], [102, 0], [119, 0]], ] -EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS = [ - [[0, 0], [20, 0], [40, 0], [50, 0], [60, 0], [70, 0], [80, 0], [90, 0]], - [[0, 0], [22, 0], [44, 0], [55, 0], [66, 0], [77, 0], [88, 0], [99, 0]], - [[0, 0], [24, 0], [48, 0], [60, 0], [72, 0], [84, 0], [96, 0], [108, 0]], - [[0, 0], [26, 0], [52, 0], [65, 0], [78, 0], [91, 0], [104, 0], [117, 0]], -] - def get_expected_cache_aux_size(rows: int) -> int: # 0.2 is the hardcoded cache load factor assumed in this test @@ -118,48 +101,6 @@ def get_expected_cache_aux_size(rows: int) -> int: ], ] -EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS = [ - [ - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - Storage(hbm=165888, ddr=0), - ], - [ - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - Storage(hbm=1001472, ddr=0), - ], - [ - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - Storage(hbm=1003520, ddr=0), - ], - [ - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - Storage(hbm=2648064, ddr=0), - ], -] EXPECTED_UVM_CACHING_RW_SHARD_STORAGE = [ [ @@ -204,48 +145,6 @@ def get_expected_cache_aux_size(rows: int) -> int: ], ] -EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS = [ - [ - Storage(hbm=166352, ddr=1600), - Storage(hbm=166352, ddr=1600), - Storage(hbm=166120, ddr=800), - Storage(hbm=166120, ddr=800), - Storage(hbm=166120, ddr=800), - Storage(hbm=166120, ddr=800), - Storage(hbm=166120, ddr=800), - Storage(hbm=166120, ddr=800), - ], - [ - Storage(hbm=1002335, ddr=3520), - Storage(hbm=1002335, ddr=3520), - Storage(hbm=1001904, ddr=1760), - Storage(hbm=1001904, ddr=1760), - Storage(hbm=1001904, ddr=1760), - Storage(hbm=1001904, ddr=1760), - Storage(hbm=1001904, ddr=1760), - Storage(hbm=1001904, ddr=1760), - ], - [ - Storage(hbm=1004845, ddr=5760), - Storage(hbm=1004845, ddr=5760), - Storage(hbm=1004183, ddr=2880), - Storage(hbm=1004183, ddr=2880), - Storage(hbm=1004183, ddr=2880), - Storage(hbm=1004183, ddr=2880), - Storage(hbm=1004183, ddr=2880), - Storage(hbm=1004183, ddr=2880), - ], - [ - Storage(hbm=2649916, ddr=8320), - Storage(hbm=2649916, ddr=8320), - Storage(hbm=2648990, ddr=4160), - Storage(hbm=2648990, ddr=4160), - Storage(hbm=2648990, ddr=4160), - Storage(hbm=2648990, ddr=4160), - Storage(hbm=2648990, ddr=4160), - Storage(hbm=2648990, ddr=4160), - ], -] EXPECTED_TWRW_SHARD_SIZES = [ [[25, 20], [25, 20], [25, 20], [25, 20]], @@ -349,16 +248,6 @@ def compute_kernels( return [EmbeddingComputeKernel.FUSED.value] -class VirtualTableRWSharder(EmbeddingBagCollectionSharder): - def sharding_types(self, compute_device_type: str) -> List[str]: - return [ShardingType.ROW_WISE.value] - - def compute_kernels( - self, sharding_type: str, compute_device_type: str - ) -> List[str]: - return [EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value] - - class UVMCachingRWSharder(EmbeddingBagCollectionSharder): def sharding_types(self, compute_device_type: str) -> List[str]: return [ShardingType.ROW_WISE.value] @@ -468,27 +357,6 @@ def setUp(self) -> None: min_partition=40, pooling_factors=[2, 1, 3, 7] ), } - self._virtual_table_constraints = { - "table_0": ParameterConstraints( - min_partition=20, - compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value], - ), - "table_1": ParameterConstraints( - min_partition=20, - pooling_factors=[1, 3, 5], - compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value], - ), - "table_2": ParameterConstraints( - min_partition=20, - pooling_factors=[8, 2], - compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value], - ), - "table_3": ParameterConstraints( - min_partition=40, - pooling_factors=[2, 1, 3, 7], - compute_kernels=[EmbeddingComputeKernel.DRAM_VIRTUAL_TABLE.value], - ), - } self.num_tables = 4 tables = [ EmbeddingBagConfig( @@ -499,17 +367,6 @@ def setUp(self) -> None: ) for i in range(self.num_tables) ] - tables_with_buckets = [ - EmbeddingBagConfig( - num_embeddings=100 + i * 10, - embedding_dim=20 + i * 20, - name="table_" + str(i), - feature_names=["feature_" + str(i)], - total_num_buckets=10, - use_virtual_table=True, - ) - for i in range(self.num_tables) - ] weighted_tables = [ EmbeddingBagConfig( num_embeddings=(i + 1) * 10, @@ -520,9 +377,6 @@ def setUp(self) -> None: for i in range(4) ] self.model = TestSparseNN(tables=tables, weighted_tables=[]) - self.model_with_buckets = EmbeddingBagCollection( - tables=tables_with_buckets, - ) self.enumerator = EmbeddingEnumerator( topology=Topology( world_size=self.world_size, @@ -532,15 +386,6 @@ def setUp(self) -> None: batch_size=self.batch_size, constraints=self.constraints, ) - self.virtual_table_enumerator = EmbeddingEnumerator( - topology=Topology( - world_size=self.world_size, - compute_device=self.compute_device, - local_world_size=self.local_world_size, - ), - batch_size=self.batch_size, - constraints=self._virtual_table_constraints, - ) self.tower_model = TestTowerSparseNN( tables=tables, weighted_tables=weighted_tables ) @@ -669,26 +514,6 @@ def test_rw_sharding(self) -> None: EXPECTED_RW_SHARD_STORAGE[i], ) - def test_virtual_table_rw_sharding_with_buckets(self) -> None: - sharding_options = self.virtual_table_enumerator.enumerate( - self.model_with_buckets, - [cast(ModuleSharder[torch.nn.Module], VirtualTableRWSharder())], - ) - for i, sharding_option in enumerate(sharding_options): - self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value) - self.assertEqual( - [shard.size for shard in sharding_option.shards], - EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i], - ) - self.assertEqual( - [shard.offset for shard in sharding_option.shards], - EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i], - ) - self.assertEqual( - [shard.storage for shard in sharding_option.shards], - EXPECTED_VIRTUAL_TABLE_RW_SHARD_STORAGE_WITH_BUCKETS[i], - ) - def test_uvm_caching_rw_sharding(self) -> None: sharding_options = self.enumerator.enumerate( self.model, @@ -710,26 +535,6 @@ def test_uvm_caching_rw_sharding(self) -> None: EXPECTED_UVM_CACHING_RW_SHARD_STORAGE[i], ) - def test_uvm_caching_rw_sharding_with_buckets(self) -> None: - sharding_options = self.enumerator.enumerate( - self.model_with_buckets, - [cast(ModuleSharder[torch.nn.Module], UVMCachingRWSharder())], - ) - for i, sharding_option in enumerate(sharding_options): - self.assertEqual(sharding_option.sharding_type, ShardingType.ROW_WISE.value) - self.assertEqual( - [shard.size for shard in sharding_option.shards], - EXPECTED_RW_SHARD_SIZES_WITH_BUCKETS[i], - ) - self.assertEqual( - [shard.offset for shard in sharding_option.shards], - EXPECTED_RW_SHARD_OFFSETS_WITH_BUCKETS[i], - ) - self.assertEqual( - [shard.storage for shard in sharding_option.shards], - EXPECTED_UVM_CACHING_RW_SHARD_STORAGE_WITH_BUCKETS[i], - ) - def test_twrw_sharding(self) -> None: sharding_options = self.enumerator.enumerate( self.model, [cast(ModuleSharder[torch.nn.Module], TWRWSharder())] diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 9b27de2bd..ebc76b976 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -130,9 +130,6 @@ def create_input_dist( ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: num_features = self._get_num_features() feature_hash_sizes = self._get_feature_hash_sizes() - virtual_table_feature_num_buckets = ( - self._get_virtual_table_feature_num_buckets() - ) return RwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. @@ -143,7 +140,6 @@ def create_input_dist( is_sequence=True, has_feature_processor=self._has_feature_processor, need_pos=False, - virtual_table_feature_num_buckets=virtual_table_feature_num_buckets, ) def create_lookup( diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index 136052137..d310127c0 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -300,20 +300,6 @@ def _get_writable_feature_hash_sizes(self) -> List[int]: feature_hash_sizes.extend(group_config.feature_hash_sizes()) return feature_hash_sizes - def _get_virtual_table_feature_num_buckets(self) -> List[int]: - feature_num_buckets: List[int] = [] - for group_config in self._grouped_embedding_configs: - for embedding_table in group_config.embedding_tables: - if ( - embedding_table.total_num_buckets is not None - and embedding_table.use_virtual_table - ): - feature_num_buckets.extend( - [embedding_table.total_num_buckets] - * embedding_table.num_features() - ) - return feature_num_buckets - class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ @@ -345,7 +331,6 @@ def __init__( has_feature_processor: bool = False, need_pos: bool = False, keep_original_indices: bool = False, - virtual_table_feature_num_buckets: Optional[List[int]] = None, ) -> None: super().__init__() self._world_size: int = pg.size() @@ -355,33 +340,11 @@ def __init__( for i, hash_size in enumerate(feature_hash_sizes): block_divisor = self._world_size - # Using different num_bucket lists for MPZCH and KVZCH allows us to process them with - # different code paths for now, enabling uneven sharding for KVZCH only. MPZCH can have - # uneven sharding enabled for it as well in the future but that will require additional testing if feature_total_num_buckets is not None: - # MPZCH sharding - assert ( - feature_total_num_buckets[i] % self._world_size == 0 - ), f"Number of buckets: {feature_total_num_buckets[i]} should be divisible by world size: {self._world_size} for MPZCH" - + assert feature_total_num_buckets[i] % self._world_size == 0 block_divisor = feature_total_num_buckets[i] - elif virtual_table_feature_num_buckets is not None and len( - virtual_table_feature_num_buckets - ): - # KVZCH uneven sharding - assert ( - virtual_table_feature_num_buckets[i] >= self._world_size - ), f"Number of buckets: {virtual_table_feature_num_buckets[i]} for a table cannot be less than the word_size: {self._world_size}" - - block_divisor = virtual_table_feature_num_buckets[i] feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor) - self.kvzch_bucketize_row_pos: Optional[List[torch._tensor.Tensor]] = ( - self._get_bucketize_row_pos( - virtual_table_feature_num_buckets, feature_block_sizes - ) - ) - self.register_buffer( "_feature_block_sizes_tensor", torch.tensor( @@ -415,37 +378,6 @@ def __init__( self.unbucketize_permute_tensor: Optional[torch.Tensor] = None self._keep_original_indices = keep_original_indices - def _get_bucketize_row_pos( - self, - feature_num_buckets: Optional[List[int]], - feature_block_sizes: List[int], - ) -> Optional[List[torch.Tensor]]: - # Creates the bucketize row positions for uneven sharding with buckets. If the number of buckets - # is greater than the world size, and world_size % num_buckets != 0, the buckets count will not be - # the same on each rank. Bucketize_row_pos object lays out the distribution of buckets in this scenario. - # For eg. - # Bucketize_row_pos - # [ - # Tensor([0, 4, 8, 12, 15, 18, 21]), Feature 1 has 4 buckets on ranks 0, 1, 2. 3 buckets on ranks 3, 4, 5 - # Tensor([0, 2, 4, 6, 7, 8, 9]), Feature 2 has 2 buckets on ranks 0, 1, 2. 3 buckets on ranks 3, 4, 5 - # ] - if feature_num_buckets is None or len(feature_num_buckets) == 0: - return None - bucketize_row_pos = [[0] for _ in range(len(feature_num_buckets))] - bucketize_row_pos_tensors = [] - for feature in range(len(feature_num_buckets)): - for rank in range(self._world_size): - bucketize_row_pos[feature].append( - bucketize_row_pos[feature][-1] - + ( - (feature_num_buckets[feature] // self._world_size) - + int(rank < feature_num_buckets[feature] % self._world_size) - ) - * feature_block_sizes[feature] - ) - bucketize_row_pos_tensors.append(torch.tensor(bucketize_row_pos[feature])) - return bucketize_row_pos_tensors - def forward( self, sparse_features: KeyedJaggedTensor, @@ -481,7 +413,6 @@ def forward( else self._need_pos ), keep_original_indices=self._keep_original_indices, - block_bucketize_row_pos=self.kvzch_bucketize_row_pos, ) return self._dist(bucketized_features) @@ -627,9 +558,6 @@ def create_input_dist( ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]: num_features = self._get_num_features() feature_hash_sizes = self._get_feature_hash_sizes() - virtual_table_feature_num_buckets = ( - self._get_virtual_table_feature_num_buckets() - ) return RwSparseFeaturesDist( # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got # `Optional[ProcessGroup]`. @@ -639,7 +567,6 @@ def create_input_dist( device=device if device is not None else self._device, is_sequence=False, has_feature_processor=self._has_feature_processor, - virtual_table_feature_num_buckets=virtual_table_feature_num_buckets, need_pos=self._need_pos, ) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index d70750d2e..81e4fad8e 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -96,7 +96,6 @@ def calculate_shard_sizes_and_offsets( sharding_type: str, col_wise_shard_dim: Optional[int] = None, device_memory_sizes: Optional[List[int]] = None, - num_buckets: Optional[int] = None, ) -> Tuple[List[List[int]], List[List[int]]]: """ Calculates sizes and offsets for tensor sharded according to provided sharding type. @@ -123,12 +122,10 @@ def calculate_shard_sizes_and_offsets( return [[rows, columns]], [[0, 0]] elif sharding_type == ShardingType.ROW_WISE.value: return ( - _calculate_rw_shard_sizes_and_offsets( - rows, world_size, columns, num_buckets - ) + _calculate_rw_shard_sizes_and_offsets(rows, world_size, columns) if not device_memory_sizes else _calculate_uneven_rw_shard_sizes_and_offsets( - rows, world_size, columns, device_memory_sizes, num_buckets + rows, world_size, columns, device_memory_sizes ) ) elif sharding_type == ShardingType.TABLE_ROW_WISE.value: @@ -173,7 +170,7 @@ def _calculate_grid_shard_sizes_and_offsets( def _calculate_rw_shard_sizes_and_offsets( - hash_size: int, num_devices: int, columns: int, num_buckets: Optional[int] = None + hash_size: int, num_devices: int, columns: int ) -> Tuple[List[List[int]], List[List[int]]]: """ Sets prefix of shard_sizes to be `math.ceil(hash_size/num_devices)`. @@ -186,43 +183,21 @@ def _calculate_rw_shard_sizes_and_offsets( Also consider the example of hash_size = 5, num_devices = 4. The expected rows per rank is [2,2,1,0]. - - If num_buckets is specified, the sharding methodology changes to adapt to ZCH. - So, if hash_size = 10, num_devices = 4, num_buckets = 5, each bucket will have 2 rows. - After distributing the buckets evenly across ranks we will have the row distribution as - [4, 2, 2, 2] """ + + block_size: int = math.ceil(hash_size / num_devices) + last_rank: int = hash_size // block_size + last_block_size: int = hash_size - block_size * last_rank shard_sizes: List[List[int]] = [] - if num_buckets: - # number of buckets being specified means zch is enabled - assert ( - hash_size % num_buckets == 0 - ), "hash_size must be divisible by num_buckets" - bucket_size: int = hash_size // num_buckets - # number of buckets per rank - shard_buckets = math.floor(num_buckets / num_devices) - # number of ranks with an extra bucket - extra_bucket_shards = num_buckets % num_devices - for rank in range(num_devices): - if rank < extra_bucket_shards: - shard_size = bucket_size * (shard_buckets + 1) - else: - shard_size = bucket_size * shard_buckets - shard_sizes.append([shard_size, columns]) - else: - block_size: int = math.ceil(hash_size / num_devices) - last_rank: int = hash_size // block_size - last_block_size: int = hash_size - block_size * last_rank - shard_sizes: List[List[int]] = [] - - for rank in range(num_devices): - if rank < last_rank: - local_row: int = block_size - elif rank == last_rank: - local_row: int = last_block_size - else: - local_row: int = 0 - shard_sizes.append([local_row, columns]) + + for rank in range(num_devices): + if rank < last_rank: + local_row: int = block_size + elif rank == last_rank: + local_row: int = last_block_size + else: + local_row: int = 0 + shard_sizes.append([local_row, columns]) shard_offsets = [[0, 0]] for i in range(num_devices - 1): @@ -232,11 +207,7 @@ def _calculate_rw_shard_sizes_and_offsets( def _calculate_uneven_rw_shard_sizes_and_offsets( - hash_size: int, - num_devices: int, - columns: int, - device_memory_sizes: List[int], - num_buckets: Optional[int] = None, + hash_size: int, num_devices: int, columns: int, device_memory_sizes: List[int] ) -> Tuple[List[List[int]], List[List[int]]]: assert num_devices == len(device_memory_sizes), "must provide all the memory size" total_size = sum(device_memory_sizes) @@ -244,20 +215,10 @@ def _calculate_uneven_rw_shard_sizes_and_offsets( last_rank = num_devices - 1 processed_total_rows = 0 - if num_buckets is None: - num_buckets = hash_size - bucket_size = 1 - else: - assert ( - hash_size % num_buckets == 0 - ), "hash_size must be divisible by num_buckets" - bucket_size = hash_size // num_buckets + for rank in range(num_devices): if rank < last_rank: - local_row: int = ( - int(num_buckets * (device_memory_sizes[rank] / total_size)) - * bucket_size - ) + local_row: int = int(hash_size * (device_memory_sizes[rank] / total_size)) processed_total_rows += local_row elif rank == last_rank: local_row: int = hash_size - processed_total_rows diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 6c585b423..02f64e859 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -19,8 +19,6 @@ QuantManagedCollisionEmbeddingCollectionSharder, ) from torchrec.distributed.sharding_plan import ( - _calculate_rw_shard_sizes_and_offsets, - _calculate_uneven_rw_shard_sizes_and_offsets, column_wise, construct_module_sharding_plan, data_parallel, @@ -1239,147 +1237,3 @@ def test_module_to_default_sharders(self) -> None: default_sharder_map[QuantManagedCollisionEmbeddingCollection], QuantManagedCollisionEmbeddingCollectionSharder, ) - - -class RowWiseShardingTest(unittest.TestCase): - def test_non_zch_rw_sharding(self) -> None: - """Test the original row-wise sharding logic (without num_buckets)""" - # Test case 1: hash_size = 10, num_devices = 4 - shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets( - hash_size=10, num_devices=4, columns=8 - ) - - # Expected: [3,3,3,1] rows per rank - expected_sizes = [[3, 8], [3, 8], [3, 8], [1, 8]] - expected_offsets = [[0, 0], [3, 0], [6, 0], [9, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - # Test case 2: hash_size = 5, num_devices = 4 - shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets( - hash_size=5, num_devices=4, columns=16 - ) - - # Expected: [2,2,1,0] rows per rank - expected_sizes = [[2, 16], [2, 16], [1, 16], [0, 16]] - expected_offsets = [[0, 0], [2, 0], [4, 0], [5, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - def test_zch_rw_sharding(self) -> None: - """Test the new row-wise sharding logic with num_buckets (ZCH)""" - # Test case 1: hash_size = 10, num_devices = 4, num_buckets = 5 - # Each bucket has 2 rows, buckets distributed as [2,1,1,1] - # So rows are distributed as [4,2,2,2] - shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets( - hash_size=10, num_devices=4, columns=8, num_buckets=5 - ) - - expected_sizes = [[4, 8], [2, 8], [2, 8], [2, 8]] - expected_offsets = [[0, 0], [4, 0], [6, 0], [8, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - # Test case 2: hash_size = 100, num_devices = 4, num_buckets = 10 - # Each bucket has 10 rows, buckets distributed as [3,3,2,2] - # So rows are distributed as [30,30,20,20] - shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets( - hash_size=100, num_devices=4, columns=16, num_buckets=10 - ) - - expected_sizes = [[30, 16], [30, 16], [20, 16], [20, 16]] - expected_offsets = [[0, 0], [30, 0], [60, 0], [80, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - # Test case 3: hash_size = 18, num_devices = 3, num_buckets = 6 - # Each bucket has 3 rows (18 // 6 = 3), buckets distributed as [2,2,2] - # So rows are distributed as [6,6,6] - shard_sizes, shard_offsets = _calculate_rw_shard_sizes_and_offsets( - hash_size=18, num_devices=3, columns=32, num_buckets=6 - ) - - expected_sizes = [[6, 32], [6, 32], [6, 32]] - expected_offsets = [[0, 0], [6, 0], [12, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - def test_uneven_rw_sharding_with_buckets(self) -> None: - """Test uneven row-wise sharding with num_buckets""" - # Test with device memory sizes [2, 1, 1] - device_memory_sizes = [2, 1, 1] - - # hash_size = 40, num_buckets = 8, bucket_size = 5 - # With memory ratio 2:1:1, buckets should be distributed as [4,2,2] - # So rows are distributed as [20,10,10] - shard_sizes, shard_offsets = _calculate_uneven_rw_shard_sizes_and_offsets( - hash_size=40, - num_devices=3, - columns=64, - device_memory_sizes=device_memory_sizes, - num_buckets=8, - ) - - expected_sizes = [[20, 64], [10, 64], [10, 64]] - expected_offsets = [[0, 0], [20, 0], [30, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - # Test without num_buckets (should use hash_size as num_buckets) - # With memory ratio 2:1:1, rows should be distributed as [20,10,10] - shard_sizes, shard_offsets = _calculate_uneven_rw_shard_sizes_and_offsets( - hash_size=40, - num_devices=3, - columns=64, - device_memory_sizes=device_memory_sizes, - ) - - expected_sizes = [[20, 64], [10, 64], [10, 64]] - expected_offsets = [[0, 0], [20, 0], [30, 0]] - - self.assertEqual(shard_sizes, expected_sizes) - self.assertEqual(shard_offsets, expected_offsets) - - def test_rw_sharding_hash_size_not_divisible_by_num_buckets(self) -> None: - """Test that _calculate_rw_shard_sizes_and_offsets raises an assertion error when hash_size is not divisible by num_buckets""" - # Test case: hash_size = 10, num_buckets = 3 (not divisible) - with self.assertRaises(AssertionError): - _calculate_rw_shard_sizes_and_offsets( - hash_size=10, num_devices=4, columns=8, num_buckets=3 - ) - - # Test case: hash_size = 100, num_buckets = 7 (not divisible) - with self.assertRaises(AssertionError): - _calculate_rw_shard_sizes_and_offsets( - hash_size=100, num_devices=4, columns=16, num_buckets=7 - ) - - def test_uneven_rw_sharding_hash_size_not_divisible_by_num_buckets(self) -> None: - """Test that _calculate_uneven_rw_shard_sizes_and_offsets raises an assertion error when hash_size is not divisible by num_buckets""" - device_memory_sizes = [2, 1, 1] - - # Test case: hash_size = 10, num_buckets = 3 (not divisible) - with self.assertRaises(AssertionError): - _calculate_uneven_rw_shard_sizes_and_offsets( - hash_size=10, - num_devices=3, - columns=64, - device_memory_sizes=device_memory_sizes, - num_buckets=3, - ) - - # Test case: hash_size = 100, num_buckets = 7 (not divisible) - with self.assertRaises(AssertionError): - _calculate_uneven_rw_shard_sizes_and_offsets( - hash_size=100, - num_devices=3, - columns=64, - device_memory_sizes=device_memory_sizes, - num_buckets=7, - )