Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions torchrec/distributed/sharding/rw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def _get_writable_feature_hash_sizes(self) -> List[int]:
return feature_hash_sizes

def _get_virtual_table_feature_num_buckets(self) -> List[int]:
"""
Returns the number of buckets for each KVZCH feature in the GroupedEmbeddingConfigs.
If a feature is not a KVZCH feature, the list will have world_size for that feature's corresponding position.
This is needed as KVZCH features have to be processed for input_dist with non-KVZCH features.
"""
feature_num_buckets: List[int] = []
for group_config in self._grouped_embedding_configs:
for embedding_table in group_config.embedding_tables:
Expand All @@ -312,6 +317,10 @@ def _get_virtual_table_feature_num_buckets(self) -> List[int]:
[embedding_table.total_num_buckets]
* embedding_table.num_features()
)
else:
feature_num_buckets.extend(
[self._world_size] * embedding_table.num_features()
)
return feature_num_buckets


Expand Down
Loading