You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
change error message and add commments for the Invalid Pipelineforward Usage assertion (meta-pytorch#3294)
Summary:
Pull Request resolved: meta-pytorch#3294
# What is PipelinedForward in TorchRec
## TL;DR
We have been seeing increasingly amount of posts (1, 2, 3, 4, …) reporting the assertion error regarding “Invalid PipelinedForward usage”
The main reason is that TorchRec modifies the original model’s forward function to apply pipelining for the distributed embedding lookup, and doesn’t expect an embedding module to run more than once in the forward pass.
The design of Torchrec’s PipelinedForward will be explained in detail, and we are willing to extend our support if you have a legitimate use case.
## Context
TorchRec provides a solution for managing distributed embeddings. This is often necessary when embedding tables are too large to be stored on a single device. From the modeler’s point of view, TorchRec’s embedding modules handle those oversized embedding tables automatically by hiding the complex inter-device communications behind the scene, and letting the modeler focus on the architecture, as shown in the diagram below.
{F1981323315}
***Fig. 1.*** *A modeler can treat TorchRec embedding modules as regular when authoring the model, demonstrating the forward pass in particular.*
Following this, the TorchRec sharder substitutes the embedding modules with sharded embedding modules. Each sharded module manages only a portion of the embedding tables by chaining its three main components: input_dist, emb_lookup, and output_dist, and it is mathematically equivalent to its unsharded counterpart as shown below.
{F1981323321}
***Fig. 2.*** *TorchRec’s distributed model parallel for embedding tables, demonstrating an EBC (EmbeddingBagCollection) is sharded into two sharded EBC in a 2-GPU environment.*
A widely used (almost adopted by all RecSys models) pipelining optimization is to launch the input_dist in an earlier training batch so that the (modified) model’s forward can start with the embedding lookup followed by the output_dist, which becomes the pipelined_forward of the embedding module, as shown in the following figure.
{F1981323324}
***Fig. 3.*** *By sharding the TorchRec’s embedding module, the model’s forward function is replaced with three chained functions: input_dist, emb_lookup, and output_dist. The input_dist is evoked in an earlier training batch while the rest forms the “pipelined forward” in the modified model’s forward function.*
Although the input_dist and pipelined_forward are split into two training batches, they still follow a one-to-one correspondence, and here comes the assertion that “the result of the input_dist is consumed (by pipelined_forward) only once”, to make sure they are correctly chained together.
## Issues and Workarounds
The most common issue is that an EBC module is called twice in the model’s forward.
1. If really need to run the forward pass twice (e.g., one with grad, one without grad), there are two possible workarounds to bypass the assertion:
a) use TrainPipelineBase instead of TrainPipelineSparseDist. In the base pipeline, the input_dist is in the same forward call with emb_lookup and output_dist.
b) call pipeline.detach() in pipeline.progress(). it basically restores the forward (as base pipeline) on the fly.
Distributed embedding lookups are complex and extensive, the above workarounds would have significant performance regression.
2. If feeding the same EBC with two different inputs (KJTs), probably want to concatenate the KJTs into a single EBC call.
3. If re-using the embedding results of the EBC, you can just make a copy of the output embedding result instead of calling the EBC twice.
Reviewed By: spmex
Differential Revision: D80514683
fbshipit-source-id: 482f48aca33cb547c5de1db38acf19b312b94689
0 commit comments