
Commit 4013381

kausv authored and facebook-github-bot committed
Fix a2a type (#3311)
Summary: Pull Request resolved: #3311

VBE initializes the dist module, but the KJT sets the ctx flag at runtime. So if the batch sizes happen to match for all features, we assume a fixed batch size, resulting in a runtime error. In this diff, I fix the dispatch to follow the dist once it is initialized. We should follow up with driving this from config.

https://fb.workplace.com/groups/1699838000485189/permalink/2222934654842185/

Differential Revision: D80742183

fbshipit-source-id: 1898040fd436a54742f78594f996a7f4e5e0225c
1 parent 1b1e2b3 commit 4013381
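
To make the failure mode concrete, here is a minimal, self-contained toy sketch of the dispatch problem described in the summary. The classes and names below (FixedBatchAllToAll, VariableBatchAllToAll, Context, forward_old, forward_new) are hypothetical stand-ins, not the torchrec API; the point is only that choosing the call path from a runtime flag can disagree with the module that was constructed at init time, whereas dispatching on the module's type cannot.

from typing import List, Optional


class FixedBatchAllToAll:
    """Stand-in for a fixed-batch pooled all-to-all module."""

    def __call__(self, embs: str, batch_size_per_rank: Optional[List[int]] = None) -> str:
        return f"fixed_a2a({embs}, batch_size_per_rank={batch_size_per_rank})"


class VariableBatchAllToAll:
    """Stand-in for a variable-batch (VBE) pooled all-to-all module."""

    def __call__(self, embs: str, batch_size_per_rank_per_feature: List[List[int]]) -> str:
        return f"vbe_a2a({embs}, per_feature={batch_size_per_rank_per_feature})"


class Context:
    """Stand-in for a sharding context built from the incoming KJT at runtime."""

    def __init__(self, per_feature: List[List[int]], per_rank: List[int]) -> None:
        self.batch_size_per_rank_per_feature = per_feature
        self.batch_size_per_rank = per_rank
        # The runtime flag is only True when feature batch sizes actually differ,
        # so equal batch sizes look like a fixed batch even when VBE is enabled.
        self.variable_batch_per_feature = len({tuple(b) for b in per_feature}) > 1


def forward_old(dist, ctx: Context, embs: str) -> str:
    # Buggy: dispatch follows the runtime flag, not the initialized module.
    if ctx.variable_batch_per_feature:
        return dist(embs, batch_size_per_rank_per_feature=ctx.batch_size_per_rank_per_feature)
    return dist(embs, batch_size_per_rank=ctx.batch_size_per_rank)


def forward_new(dist, ctx: Context, embs: str) -> str:
    # Fixed: dispatch follows the type of the dist module chosen at init.
    if isinstance(dist, VariableBatchAllToAll):
        return dist(embs, batch_size_per_rank_per_feature=ctx.batch_size_per_rank_per_feature)
    return dist(embs, batch_size_per_rank=ctx.batch_size_per_rank)


if __name__ == "__main__":
    dist = VariableBatchAllToAll()  # VBE enabled at initialization
    ctx = Context(per_feature=[[4, 4], [4, 4]], per_rank=[8, 8])  # sizes happen to match
    print(forward_new(dist, ctx, "embs"))  # takes the VBE path, as the module expects
    try:
        forward_old(dist, ctx, "embs")  # flag says fixed batch -> wrong call signature
    except TypeError as err:
        print(f"old dispatch fails at runtime: {err}")

The commit applies the same idea inside the forward paths below: an isinstance check on self._dist (and on self._intra_dist / self._cross_dist in twrw) replaces the variable_batch_per_feature flag.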

2 files changed: +23 −14 lines

torchrec/distributed/sharding/tw_sharding.py

Lines changed: 12 additions & 11 deletions
@@ -367,20 +367,21 @@ def forward(
         """
         if self._dist is None:
             self._create_output_dist_module(sharding_ctx)
-
-        if sharding_ctx is None:
-            return cast(PooledEmbeddingsAllToAll, self._dist)(local_embs)
-        elif sharding_ctx.variable_batch_per_feature:
+        if isinstance(self._dist, VariableBatchPooledEmbeddingsAllToAll):
+            sharding_ctx = none_throws(sharding_ctx)
             return cast(VariableBatchPooledEmbeddingsAllToAll, self._dist)(
                 local_embs,
-                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature,
-                batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a,
-            )
-        else:
-            return cast(PooledEmbeddingsAllToAll, self._dist)(
-                local_embs,
-                batch_size_per_rank=sharding_ctx.batch_size_per_rank,
+                batch_size_per_rank_per_feature=sharding_ctx.batch_size_per_rank_per_feature
+                or sharding_ctx.batch_size_per_rank,
+                batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a
+                or sharding_ctx.batch_size_per_rank,
             )
+        return cast(PooledEmbeddingsAllToAll, self._dist)(
+            local_embs,
+            batch_size_per_rank=(
+                sharding_ctx.batch_size_per_rank if sharding_ctx else None
+            ),
+        )
 
     def _create_output_dist_module(
         self, sharding_ctx: Optional[EmbeddingShardingContext] = None
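
A small illustration of the fallback expressions in the new arguments above. The values are made up and sharding_ctx here is a plain stand-in, not EmbeddingShardingContext, but the pattern is the one in the diff: Python's "or" picks the right-hand operand when the left one is None or empty, so a context that carries no per-feature batch sizes falls back to the per-rank sizes, and a missing context yields None on the fixed-batch path.

from types import SimpleNamespace
from typing import Optional

ctx = SimpleNamespace(
    batch_size_per_rank_per_feature=[],  # KJT carried no per-feature sizes
    batch_size_per_rank=[8, 8],
)

# VBE path: an empty per-feature list falls back to the per-rank sizes.
per_feature_arg = ctx.batch_size_per_rank_per_feature or ctx.batch_size_per_rank
print(per_feature_arg)  # [8, 8]

# Fixed-batch path: guard against the context being absent entirely.
sharding_ctx: Optional[SimpleNamespace] = None
per_rank_arg = sharding_ctx.batch_size_per_rank if sharding_ctx else None
print(per_rank_arg)  # None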

torchrec/distributed/sharding/twrw_sharding.py

Lines changed: 11 additions & 3 deletions
@@ -504,14 +504,21 @@ def forward(
             self._create_output_dist_modules(sharding_ctx)
         local_rank = self._rank % self._intra_pg.size()
         current_node = self._rank // self._intra_pg.size()
-        if sharding_ctx is not None and sharding_ctx.variable_batch_per_feature:
+        if isinstance(
+            self._intra_dist, VariableBatchPooledEmbeddingsReduceScatter
+        ) and isinstance(self._cross_dist, VariableBatchPooledEmbeddingsAllToAll):
+            assert sharding_ctx is not None and (
+                sharding_ctx.batch_size_per_rank_per_feature
+                or sharding_ctx.batch_size_per_rank
+            ), "Batch size not found in KJT input for VBE"
             (
                 batch_size_per_rank_per_feature_by_cross_group,
                 batch_size_per_feature_sum_by_cross_group,
             ) = self._preprocess_batch_size_per_rank_per_feature(
                 self._intra_pg.size(),
                 self._cross_pg.size(),
-                sharding_ctx.batch_size_per_rank_per_feature,
+                sharding_ctx.batch_size_per_rank_per_feature
+                or [sharding_ctx.batch_size_per_rank],
             )
             rs_result = cast(
                 VariableBatchPooledEmbeddingsReduceScatter, self._intra_dist
@@ -525,7 +532,8 @@ def forward(
                 batch_size_per_rank_per_feature=batch_size_per_rank_per_feature_by_cross_group[
                     local_rank
                 ],
-                batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a,
+                batch_size_per_feature_pre_a2a=sharding_ctx.batch_size_per_feature_pre_a2a
+                or sharding_ctx.batch_size_per_rank,
             )
         elif (
             sharding_ctx is not None and len(set(sharding_ctx.batch_size_per_rank)) > 1
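
The twrw path adds an explicit guard on top of the same fallback idea: if both dist modules were built for variable batch sizes, the context must carry some batch size information, and the per-rank sizes are wrapped in a single-element list so they can stand in for the per-feature structure. A hedged sketch with made-up values (ctx is a stand-in, not EmbeddingShardingContext):

from types import SimpleNamespace

ctx = SimpleNamespace(
    batch_size_per_rank_per_feature=[],  # no per-feature sizes in the KJT
    batch_size_per_rank=[8, 8],
)

# Guard: fail early with a clear message instead of a confusing collective error.
assert ctx is not None and (
    ctx.batch_size_per_rank_per_feature or ctx.batch_size_per_rank
), "Batch size not found in KJT input for VBE"

# Fallback mirrors "or [sharding_ctx.batch_size_per_rank]" in the diff:
# the per-rank sizes become a single per-feature group.
per_feature = ctx.batch_size_per_rank_per_feature or [ctx.batch_size_per_rank]
print(per_feature)  # [[8, 8]]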

0 commit comments
