Skip to content

Commit a4832f4

Browse files
TroyGarden and facebook-github-bot
authored and committed
add batch index in the fwd/bwd (meta-pytorch#3326)
Summary: Pull Request resolved: meta-pytorch#3326 # context * With more complex pipelining techniques, it's very helpful to know which batch an action is operating on. * This diff adds the batch id to the trace to clearly mark the fwd and bwd passes. * old trace {F1981558168} * new trace {F1981558180} Reviewed By: spmex Differential Revision: D81159416 fbshipit-source-id: 3889942f7564278352ec9deab4e52f49ff2cc2d7
1 parent f61e127 commit a4832f4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,13 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
640640
return
641641

642642
def _wait_for_batch(self) -> None:
643-
with record_function("## wait_for_batch ##"):
643+
batch_id = self.contexts[0].index if len(self.contexts) > 0 else "?"
644+
with record_function(f"## wait_for_batch {batch_id} ##"):
644645
_wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream)
645646

646647
def _backward(self, losses: torch.Tensor) -> None:
647-
with record_function("## backward ##"):
648+
batch_id = self.contexts[0].index if len(self.contexts) > 0 else "?"
649+
with record_function(f"## backward {batch_id} ##"):
648650
torch.sum(losses, dim=0).backward()
649651

650652
def progress(self, dataloader_iter: Iterator[In]) -> Out:
@@ -688,7 +690,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
688690
self.enqueue_batch(dataloader_iter)
689691

690692
# forward
691-
with record_function("## forward ##"):
693+
with record_function(f"## forward {self.contexts[0].index} ##"):
692694
self._state = PipelineState.CALL_FWD
693695
losses, output = self._model_fwd(self.batches[0])
694696

@@ -714,7 +716,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
714716
)
715717

716718
# update
717-
with record_function("## optimizer ##"):
719+
with record_function(f"## optimizer {self.contexts[0].index} ##"):
718720
self._optimizer.step()
719721

720722
self.dequeue_batch()
@@ -1063,7 +1065,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
10631065
self.enqueue_batch(dataloader_iter)
10641066

10651067
# forward
1066-
with record_function("## forward ##"):
1068+
with record_function(f"## forward {self.contexts[0].index} ##"):
10671069
losses, output = self._model_fwd(self.batches[0])
10681070

10691071
if len(self.batches) >= 2:

0 commit comments

Comments
 (0)