
Commit e289caa

iamzainhuda authored and facebook-github-bot committed
Fully Sharded 2D Parallelism (#3558)
Summary: **This diff introduces Fully Sharded 2D Parallelism in TorchRec. It brings significant memory savings (50%+) by sharding embedding tables when they are not in use.** After the embedding lookup, the embedding table is further sharded across the data parallel dimension until it is needed in the backward pass. This gives the model layers that run after the embedding lookup more memory headroom, enabling further scaling of the dense architecture. **Practically speaking, this saves 50%+ of embedding memory per GPU, which accounts for upwards of 10GB of memory savings on large models.**

The peak memory during this step becomes `O(shard + shard/num_replication)`, which then drops to an embedding memory of `O(shard/num_replication)` after the lookup step. The memory freeing and collective communication are done in an overhead-free manner by overlapping computation with the communication collectives through asynchronous handling on multiple streams.

With Fully Sharded 2D, the embedding weight synchronization has to happen every step, otherwise trained batches are lost across ranks. We use an asynchronous reduce scatter after the embedding lookup step and fully overlap this collective with compute, so no additional overhead is exposed. A new awaitable, `ReduceScatterResizeAwaitable`, is introduced under the Fully Sharded path and is invoked together with the SDD output_dist all-to-all. This awaitable `wait()`s on the async reduce scatter and then calls the `resize()` operation on the embedding memory, ensuring no race conditions.

Users can enable Fully Sharded 2D through a new arg, `ShardingStrategy`:

```
DMPCollection(..., sharding_strategy=ShardingStrategy.FULLY_SHARDED)
```

This is part of our work to create an overhead-free 2D parallelism that can be used for every model. Remaining work from this diff: launch an async all-gather in the backward pass, make the planner aware of these memory savings, and integrate this work with per-module 2D.

Differential Revision: D82253387
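To make the memory claim concrete, here is a small back-of-the-envelope sketch with hypothetical numbers (an 8 GB shard and a replication factor of 4; neither figure comes from this diff):

```python
# Hypothetical numbers illustrating the O(shard + shard/num_replication) peak
# and the O(shard/num_replication) steady state described above.
shard_gb = 8.0        # embedding shard held by this rank right after the lookup
num_replication = 4   # data-parallel replicas the shard is scattered across

peak_gb = shard_gb + shard_gb / num_replication   # during the async reduce scatter
steady_gb = shard_gb / num_replication            # after resize() frees the full shard
saved_gb = shard_gb - steady_gb

print(f"peak:   {peak_gb:.1f} GB")                               # 10.0 GB
print(f"steady: {steady_gb:.1f} GB")                             # 2.0 GB
print(f"saved:  {saved_gb:.1f} GB ({saved_gb / shard_gb:.0%})")  # 6.0 GB (75%)
```

With a replication factor of 2, the steady-state saving is 50% of the shard, which matches the 50%+ figure quoted in the summary.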
1 parent 217889e commit e289caa

16 files changed: +512 -31 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 364 additions & 0 deletions
Large diffs are not rendered by default.
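Since this large diff is not rendered, the sketch below illustrates the mechanism the summary attributes to the new kernels in this file (`ShardedBatchedFusedEmbedding`/`ShardedBatchedFusedEmbeddingBag`): after the lookup, the fused weights are reduce-scattered asynchronously across the replica group, and the returned awaitable frees the full shard's storage only when `wait()` is called. This is a minimal sketch under assumptions, not the actual implementation; the class and function names are invented, and only `torch.distributed.reduce_scatter_tensor`, the `Work.wait()` handle, and `Tensor.untyped_storage().resize_(0)` are real APIs. The reduction op and any scaling used by the real weight sync are not shown.

```python
import torch
import torch.distributed as dist


class ReduceScatterResizeAwaitableSketch:
    """Illustrative stand-in: wait on the async reduce scatter, then free the full shard."""

    def __init__(self, work, full_weight: torch.Tensor, local_shard: torch.Tensor) -> None:
        self._work = work                # Work handle from the async collective
        self._full_weight = full_weight
        self._local_shard = local_shard  # this rank's shard/num_replication slice

    def wait(self) -> torch.Tensor:
        self._work.wait()                               # collective finished on its stream
        self._full_weight.untyped_storage().resize_(0)  # free the full shard's memory
        return self._local_shard


def reduce_scatter_after_lookup(full_weight: torch.Tensor, replica_pg) -> ReduceScatterResizeAwaitableSketch:
    """Kick off the async reduce scatter across the replica group (assumes numel divides evenly)."""
    world = dist.get_world_size(replica_pg)
    local_shard = torch.empty(
        full_weight.numel() // world, dtype=full_weight.dtype, device=full_weight.device
    )
    work = dist.reduce_scatter_tensor(local_shard, full_weight, group=replica_pg, async_op=True)
    return ReduceScatterResizeAwaitableSketch(work, full_weight, local_shard)
```

In the actual diff, the kernels expose this handle through `get_rs_awaitable()`, which the lookup modules aggregate via `get_resize_awaitables()` (see the diffs below).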

torchrec/distributed/embedding.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -343,6 +343,7 @@ def __init__(
         module_fqn: Optional[str] = None,
         sharding_types: Optional[List[str]] = None,
         use_gather_select: bool = False,
+        resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None,
     ) -> None:
         super().__init__()
         self._awaitables_per_sharding = awaitables_per_sharding
@@ -354,6 +355,7 @@ def __init__(
         self._module_fqn = module_fqn
         self._sharding_types = sharding_types
         self._use_gather_select = use_gather_select
+        self._resize_awaitables = resize_awaitables
 
     def _wait_impl(self) -> Dict[str, JaggedTensor]:
         jt_dict: Dict[str, JaggedTensor] = {}
@@ -398,6 +400,12 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                     use_gather_select=self._use_gather_select,
                 )
             )
+
+        # free memory and resize
+        # pyre-ignore[16]
+        for awaitable in self._resize_awaitables:
+            awaitable.wait()
+
         return jt_dict
 
 
@@ -1588,6 +1596,8 @@ def compute_and_output_dist(
     ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
         awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
         features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []
+        resize_awaitables = []
+
         for lookup, odist, features, sharding_ctx, sharding_type in zip(
             self._lookups,
             self._output_dists,
@@ -1604,6 +1614,9 @@ def compute_and_output_dist(
                 EmbeddingEvent.LOOKUP, self._module_fqn, sharding_type
             ):
                 embs = lookup(features)
+                if hasattr(lookup, "get_resize_awaitables"):
+                    # pyre-ignore[29]
+                    resize_awaitables.extend(lookup.get_resize_awaitables())
                 if self.post_lookup_tracker_fn is not None:
                     self.post_lookup_tracker_fn(features, embs, self, None)
 
@@ -1631,6 +1644,7 @@ def compute_and_output_dist(
             module_fqn=self._module_fqn,
             sharding_types=list(self._sharding_type_to_sharding.keys()),
             use_gather_select=self._use_gather_select,
+            resize_awaitables=resize_awaitables,
         )
 
     def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
```

torchrec/distributed/embedding_lookup.py

Lines changed: 70 additions & 14 deletions
```diff
@@ -39,6 +39,8 @@
     BatchedFusedEmbeddingBag,
     KeyValueEmbedding,
     KeyValueEmbeddingBag,
+    ShardedBatchedFusedEmbedding,
+    ShardedBatchedFusedEmbeddingBag,
     ZeroCollisionEmbeddingCache,
     ZeroCollisionKeyValueEmbedding,
     ZeroCollisionKeyValueEmbeddingBag,
@@ -65,7 +67,15 @@
     QuantBatchedEmbedding,
     QuantBatchedEmbeddingBag,
 )
-from torchrec.distributed.types import rank_device, ShardedTensor, ShardingType
+from torchrec.distributed.types import (
+    LazyAwaitable,
+    rank_device,
+    ShardedTensor,
+    ShardingEnv,
+    ShardingEnv2D,
+    ShardingStrategy,
+    ShardingType,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -185,12 +195,15 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
+        env: Optional[ShardingEnv] = None,
     ) -> None:
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         self._need_prefetch: bool = False
         for config in grouped_configs:
-            self._emb_modules.append(self._create_embedding_kernel(config, pg, device))
+            self._emb_modules.append(
+                self._create_embedding_kernel(config, pg, device, env)
+            )
 
         self._feature_splits: List[int] = []
         for config in grouped_configs:
@@ -218,6 +231,7 @@ def _create_embedding_kernel(
         config: GroupedEmbeddingConfig,
         pg: Optional[dist.ProcessGroup],
         device: Optional[torch.device],
+        env: Optional[ShardingEnv] = None,
     ) -> BaseEmbedding:
         for table in config.embedding_tables:
             if (
@@ -234,11 +248,20 @@ def _create_embedding_kernel(
                 device=device,
             )
         elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
-            return BatchedFusedEmbedding(
-                config=config,
-                pg=pg,
-                device=device,
-            )
+            if (
+                env
+                and isinstance(env, ShardingEnv2D)
+                and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED
+            ):
+                return ShardedBatchedFusedEmbedding(
+                    config=config, pg=pg, device=device, env=env
+                )
+            else:
+                return BatchedFusedEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                )
         elif config.compute_kernel == EmbeddingComputeKernel.KEY_VALUE:
             return KeyValueEmbedding(
                 config=config,
@@ -329,6 +352,14 @@ def forward(
 
         return embeddings_cat_empty_rank_handle(embeddings, self._dummy_embs_tensor)
 
+    def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]:
+        # TODO - we can probably do some smart grouping to make this more efficient
+        return [
+            emb_module.get_rs_awaitable()  # pyre-ignore[29]
+            for emb_module in self._emb_modules
+            if hasattr(emb_module, "get_rs_awaitable")
+        ]
+
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
         self,
@@ -512,12 +543,14 @@ def __init__(
         feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
         scale_weight_gradients: bool = True,
         sharding_type: Optional[ShardingType] = None,
+        env: Optional[ShardingEnv] = None,
     ) -> None:
         super().__init__()
+        self._env = env
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
             self._emb_modules.append(
-                self._create_embedding_kernel(config, device, pg, sharding_type)
+                self._create_embedding_kernel(config, device, pg, sharding_type, env)
             )
 
         self._feature_splits: List[int] = []
@@ -555,6 +588,7 @@ def _create_embedding_kernel(
         device: Optional[torch.device],
         pg: Optional[dist.ProcessGroup],
         sharding_type: Optional[ShardingType],
+        env: Optional[ShardingEnv],
     ) -> BaseEmbedding:
         if config.compute_kernel == EmbeddingComputeKernel.DENSE:
             return BatchedDenseEmbeddingBag(
@@ -564,12 +598,26 @@ def _create_embedding_kernel(
                 sharding_type=sharding_type,
             )
         elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
-            return BatchedFusedEmbeddingBag(
-                config=config,
-                pg=pg,
-                device=device,
-                sharding_type=sharding_type,
-            )
+            if (
+                env
+                and isinstance(env, ShardingEnv2D)
+                and env.sharding_strategy == ShardingStrategy.FULLY_SHARDED
+            ):
+                return ShardedBatchedFusedEmbeddingBag(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    sharding_type=sharding_type,
+                    env=env,
+                )
+            else:
+                return BatchedFusedEmbeddingBag(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    sharding_type=sharding_type,
+                    env=env,
+                )
         elif config.compute_kernel in {
             EmbeddingComputeKernel.KEY_VALUE,
         }:
@@ -744,6 +792,14 @@ def forward(
             dim=1,
         )
 
+    def get_resize_awaitables(self) -> List[LazyAwaitable[torch.Tensor]]:
+        # TODO - we can probably do some smart grouping to make this more efficient
+        return [
+            emb_module.get_rs_awaitable()  # pyre-ignore[29]
+            for emb_module in self._emb_modules
+            if hasattr(emb_module, "get_rs_awaitable")
+        ]
+
     # pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
     def state_dict(
         self,
```

torchrec/distributed/embeddingbag.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -407,13 +407,15 @@ def __init__(
         embedding_names: List[str],
         module_fqn: Optional[str] = None,
         sharding_types: Optional[List[str]] = None,
+        resize_awaitables: Optional[List[Awaitable[torch.Tensor]]] = None,
     ) -> None:
         super().__init__()
         self._awaitables = awaitables
         self._embedding_dims = embedding_dims
         self._embedding_names = embedding_names
         self._module_fqn = module_fqn
         self._sharding_types = sharding_types
+        self._resize_awaitables = resize_awaitables
 
     def _wait_impl(self) -> KeyedTensor:
         embeddings = []
@@ -425,6 +427,12 @@ def _wait_impl(self) -> KeyedTensor:
         ):
             embeddings.append(w.wait())
 
+        # free memory and resize
+        if self._resize_awaitables is not None:
+            # pyre-ignore[16]
+            for awaitable in self._resize_awaitables:
+                awaitable.wait()
+
         return construct_output_kt(
             embeddings=embeddings,
             embedding_names=self._embedding_names,
@@ -1655,6 +1663,7 @@ def compute_and_output_dist(
         """
         batch_size_per_feature_pre_a2a = []
         awaitables = []
+        resize_awaitables = []
 
         # No usage of zip for dynamo
         for i in range(len(self._lookups)):
@@ -1669,7 +1678,11 @@ def compute_and_output_dist(
                 self._module_fqn,
                 sharding_type,
             ):
+                # with fully sharded 2D enabled, it returns an awaitable for the reduce scatter and resize operation
                 embs = lookup(features)
+                if hasattr(lookup, "get_resize_awaitables"):
+                    # pyre-ignore[29]
+                    resize_awaitables.extend(lookup.get_resize_awaitables())
                 if self.post_lookup_tracker_fn is not None:
                     self.post_lookup_tracker_fn(features, embs, self, None)
 
@@ -1710,6 +1723,7 @@ def compute_and_output_dist(
             embedding_names=self._embedding_names,
            module_fqn=self._module_fqn,
             sharding_types=self._sharding_types,
+            resize_awaitables=resize_awaitables,
         )
 
         # register callback if there are features that need mean pooling
```
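Distilled from the pattern shown in embedding.py and embeddingbag.py above: resize awaitables are gathered with a duck-typed `hasattr` check (only fully sharded lookups expose `get_resize_awaitables`) and drained only after the output-dist results have been waited on. The toy classes below are stand-ins for illustration, not TorchRec types.

```python
from typing import Any, List


class _ResizeAwaitableStub:
    """Stands in for the reduce-scatter/resize awaitable; wait() is where memory would be freed."""

    def __init__(self, name: str) -> None:
        self._name = name

    def wait(self) -> None:
        print(f"reduce scatter complete, shard freed: {self._name}")


class _FullyShardedLookupStub:
    def get_resize_awaitables(self) -> List[_ResizeAwaitableStub]:
        return [_ResizeAwaitableStub("table_0"), _ResizeAwaitableStub("table_1")]


class _PlainLookupStub:
    pass  # no get_resize_awaitables, e.g. a non fully-sharded kernel


def collect_resize_awaitables(lookups: List[Any]) -> List[_ResizeAwaitableStub]:
    # Same duck-typed dispatch as in the diffs above: only lookups exposing the hook contribute.
    collected: List[_ResizeAwaitableStub] = []
    for lookup in lookups:
        if hasattr(lookup, "get_resize_awaitables"):
            collected.extend(lookup.get_resize_awaitables())
    return collected


# Drained only after the output-dist all-to-all has been waited on, which is what
# avoids racing against the embedding memory that is about to be freed.
for awaitable in collect_resize_awaitables([_FullyShardedLookupStub(), _PlainLookupStub()]):
    awaitable.wait()
```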
