Commit 0743ee1

nipung90 authored and facebook-github-bot committed
Fix input_dist for when it has to handle KVZCH and non-KVZCH tables together for row-wise sharding (#3444)
Summary: In KVZCH scenarios, the input_dist for row-wise sharding has to handle KVZCH and non-KVZCH features together. This means that virtual_table_feature_num_buckets has to cover non-KVZCH features too. In this diff, we default num_buckets to world_size for non-KVZCH features.

Differential Revision: D84094039
1 parent 6de403e commit 0743ee1

1 file changed, +9 -0 lines changed

torchrec/distributed/sharding/rw_sharding.py

Lines changed: 9 additions & 0 deletions
@@ -301,6 +301,11 @@ def _get_writable_feature_hash_sizes(self) -> List[int]:
         return feature_hash_sizes
 
     def _get_virtual_table_feature_num_buckets(self) -> List[int]:
+        """
+        Returns the number of buckets for each KVZCH feature in the GroupedEmbeddingConfigs.
+        If a feature is not a KVZCH feature, the list will have world_size for that feature's corresponding position.
+        This is needed as KVZCH features have to be processed for input_dist with non-KVZCH features.
+        """
         feature_num_buckets: List[int] = []
         for group_config in self._grouped_embedding_configs:
             for embedding_table in group_config.embedding_tables:
@@ -312,6 +317,10 @@ def _get_virtual_table_feature_num_buckets(self) -> List[int]:
                         [embedding_table.total_num_buckets]
                         * embedding_table.num_features()
                     )
+                else:
+                    feature_num_buckets.extend(
+                        [self._world_size] * embedding_table.num_features()
+                    )
         return feature_num_buckets
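
The effect of the new else branch can be illustrated with a minimal standalone sketch. It is not the torchrec implementation: TableSketch and feature_num_buckets are hypothetical names, and the check for a KVZCH table (total_num_buckets being set) is an assumption, since the actual condition sits in the lines the diff does not show.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TableSketch:
    """Hypothetical stand-in for an embedding table config (not the torchrec class)."""

    feature_names: List[str]
    # Assumption: only KVZCH tables carry a bucket count.
    total_num_buckets: Optional[int] = None

    def num_features(self) -> int:
        return len(self.feature_names)


def feature_num_buckets(tables: List[TableSketch], world_size: int) -> List[int]:
    """Return one bucket count per feature, in table order."""
    result: List[int] = []
    for table in tables:
        # Assumed KVZCH check; the real condition is elided in the diff.
        if table.total_num_buckets is not None:
            result.extend([table.total_num_buckets] * table.num_features())
        else:
            # Behavior added by this commit: non-KVZCH features default to world_size.
            result.extend([world_size] * table.num_features())
    return result


tables = [
    TableSketch(feature_names=["f0", "f1"], total_num_buckets=16),  # KVZCH-like table
    TableSketch(feature_names=["f2"]),  # regular table
]
buckets = feature_num_buckets(tables, world_size=4)
print(buckets)  # [16, 16, 4]
```

With the else branch in place, the returned list has one entry per feature across all tables, so it stays positionally aligned with the features when KVZCH and non-KVZCH tables are mixed.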
