
Commit 39c11c1

TroyGarden authored and facebook-github-bot committed
change error message and add comments for the Invalid PipelinedForward usage assertion (#3294)
Summary:
Pull Request resolved: #3294

# What is PipelinedForward in TorchRec

## TL;DR

We have been seeing an increasing number of posts (1, 2, 3, 4, …) reporting the assertion error about "Invalid PipelinedForward usage". The main reason is that TorchRec modifies the original model's forward function to apply pipelining to the distributed embedding lookup, and it does not expect an embedding module to run more than once in the forward pass. The design of TorchRec's PipelinedForward is explained in detail below, and we are willing to extend our support if you have a legitimate use case.

## Context

TorchRec provides a solution for managing distributed embeddings, which is often necessary when embedding tables are too large to be stored on a single device. From the modeler's point of view, TorchRec's embedding modules handle those oversized embedding tables automatically, hiding the complex inter-device communication behind the scenes and letting the modeler focus on the architecture, as shown in the diagram below.

{F1981323315}

***Fig. 1.*** *A modeler can treat TorchRec embedding modules as regular modules when authoring the model; the diagram shows the forward pass in particular.*

Following this, the TorchRec sharder substitutes the embedding modules with sharded embedding modules. Each sharded module manages only a portion of the embedding tables by chaining its three main components: input_dist, emb_lookup, and output_dist, and it is mathematically equivalent to its unsharded counterpart, as shown below.

{F1981323321}

***Fig. 2.*** *TorchRec's distributed model parallel for embedding tables, showing an EBC (EmbeddingBagCollection) sharded into two sharded EBCs in a 2-GPU environment.*

A widely used pipelining optimization (adopted by almost all RecSys models) is to launch the input_dist during an earlier training batch, so that the (modified) model's forward can start with the embedding lookup followed by the output_dist; these two steps become the pipelined_forward of the embedding module, as shown in the following figure.

{F1981323324}

***Fig. 3.*** *After sharding TorchRec's embedding module, the model's forward function is replaced with three chained functions: input_dist, emb_lookup, and output_dist. The input_dist is invoked during an earlier training batch, while the rest forms the "pipelined forward" in the modified model's forward function.*

Although the input_dist and pipelined_forward are split across two training batches, they still follow a one-to-one correspondence. This is where the assertion comes from: "the result of the input_dist is consumed (by pipelined_forward) only once", which makes sure the two are correctly chained together.

## Issues and Workarounds

The most common issue is that an EBC module is called twice in the model's forward.

1. If you really need to run the forward pass twice (e.g., once with grad and once without), there are two possible workarounds to bypass the assertion:
   a) Use TrainPipelineBase instead of TrainPipelineSparseDist. In the base pipeline, the input_dist runs in the same forward call as emb_lookup and output_dist.
   b) Call pipeline.detach() before the direct model.forward() call, in between pipeline.progress() calls; it basically restores the original forward (as in the base pipeline) on the fly.
   Because distributed embedding lookups are complex and expensive, both workarounds can cause a significant performance regression.
2. If you are feeding the same EBC with two different inputs (KJTs), you probably want to concatenate the KJTs into a single EBC call (see the sketch after this list).
3. If you are re-using the embedding results of the EBC, you can simply make a copy of the output embedding result instead of calling the EBC twice.
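To make workarounds 2 and 3 concrete, here is a minimal sketch. It uses a tiny unsharded EmbeddingBagCollection; the table config `t1`, feature name `f1`, and the two example KJTs are made up for illustration and are not part of this PR. The same call pattern applies once the module is sharded and wrapped by a train pipeline.

```python
import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Tiny EBC used only to illustrate the call pattern; in a real job the EBC is
# sharded by DistributedModelParallel and wrapped by a train pipeline.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
        )
    ],
    device=torch.device("cpu"),
)

# Two KJTs that we might be tempted to push through the EBC in two separate calls.
kjt_a = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"], values=torch.tensor([1, 2, 3]), lengths=torch.tensor([2, 1])
)
kjt_b = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"], values=torch.tensor([4, 5]), lengths=torch.tensor([1, 1])
)

# Workaround 2: merge the two requests into one KJT (here by concatenating
# values/lengths along the batch dimension) and call the EBC once, so each
# training batch issues a single input_dist.
combined = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.cat([kjt_a.values(), kjt_b.values()]),
    lengths=torch.cat([kjt_a.lengths(), kjt_b.lengths()]),
)
pooled = ebc(combined)      # KeyedTensor of pooled embeddings
emb = pooled["f1"]          # rows 0-1 come from kjt_a, rows 2-3 from kjt_b
emb_a, emb_b = emb[:2], emb[2:]

# Workaround 3: reuse the lookup result instead of calling the EBC again,
# e.g. a detached copy for a consumer that should not receive gradients.
emb_no_grad = emb.detach()
```

In both cases the key property is preserved: each training batch calls into the EBC exactly once, so the result of its input_dist is produced and consumed exactly once.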
Reviewed By: spmex

Differential Revision: D80514683

fbshipit-source-id: 482f48aca33cb547c5de1db38acf19b312b94689
1 parent 08a5a82 commit 39c11c1

File tree

1 file changed: +6 -2 lines


torchrec/distributed/train_pipeline/runtime_forwards.py

Lines changed: 6 additions & 2 deletions
@@ -74,7 +74,9 @@ class PipelinedForward(BaseForward[TrainPipelineContext]):
     def __call__(self, *input, **kwargs) -> Awaitable:
         assert (
             self._name in self._context.input_dist_tensors_requests
-        ), "Invalid PipelinedForward usage, please do not directly call model.forward()"
+        ), f"Invalid PipelinedForward usage, input_dist of {self._name} is not available, probably consumed by others"
+        # we made a basic assumption that an embedding module (EBC, EC, etc.) should only be evoked only
+        # once in the model's forward pass. For more details: https://github.com/pytorch/torchrec/pull/3294
         request = self._context.input_dist_tensors_requests.pop(self._name)
         assert isinstance(request, Awaitable)
         with record_function("## wait_sparse_data_dist ##"):
@@ -121,7 +123,9 @@ def __call__(
     ]:
         assert (
             self._name in self._context.embedding_a2a_requests
-        ), "Invalid EmbeddingPipelinedForward usage, please call pipeline.detach() before torch.no_grad() and/or model.forward()"
+        ), f"Invalid PipelinedForward usage, input_dist of {self._name} is not available, probably consumed by others"
+        # we made a basic assumption that an embedding module (EBC, EC, etc.) should only be evoked only
+        # once in the model's forward pass. For more details: https://github.com/pytorch/torchrec/pull/3294

         ctx = self._context.module_contexts.pop(self._name)
         cur_stream = torch.get_device_module(self._device).current_stream()

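For reference, the message removed in the second hunk above pointed users at pipeline.detach(), which corresponds to workaround 1b in the summary. Below is a hedged sketch of that pattern; `model`, `optimizer`, `dataloader_iter`, `eval_batch`, and `device` are hypothetical placeholders (the model is assumed to already be wrapped in DistributedModelParallel), and the exact detach/re-attach behavior may vary across TorchRec versions.

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDist


def train_with_extra_eval_forward(model, optimizer, dataloader_iter, eval_batch, device):
    # Hypothetical driver loop; `model` is assumed to contain sharded
    # (DistributedModelParallel-wrapped) embedding modules.
    pipeline = TrainPipelineSparseDist(model, optimizer, device)

    while True:
        try:
            # Normal path: progress() starts the input_dist for an upcoming batch
            # and runs the pipelined forward (emb_lookup + output_dist) for the
            # current one.
            pipeline.progress(dataloader_iter)
        except StopIteration:
            break

        # Workaround 1b: before calling the model's forward directly (e.g. an
        # extra no-grad evaluation pass), detach the pipeline so the original,
        # non-pipelined forward is restored; otherwise the direct call would
        # look for an input_dist result that is not available and trip the
        # assertion above.
        pipeline.detach()
        with torch.no_grad():
            model(eval_batch)
```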