Commit f72772b

wconstab authored and pytorchmergebot committed

[PP] make runtime dbg log print custom actions (pytorch#167113)

Previously, the log only printed when an action's default implementation was used; now it prints before dispatching to custom registered actions. Tested by running the autoparallel graph runner and observing the forward-pass action being logged.

Pull Request resolved: pytorch#167113
Approved by: https://github.com/sanketpurandare, https://github.com/Skylion007
1 parent 981dd71 commit f72772b

File tree

1 file changed

+5
-6
lines changed


torch/distributed/pipelining/schedules.py

Lines changed: 5 additions & 6 deletions
@@ -2033,12 +2033,6 @@ def _perform_action(action: _Action) -> None:
             is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage
             is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage

-            logger.debug(
-                "_PipelineScheduleRuntime running time_step %d, action %s",
-                time_step,
-                action,
-            )
-
             # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
             # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
             # safe to use instead.
@@ -2191,6 +2185,11 @@ def _perform_action(action: _Action) -> None:
         # count either full_backward or backward_weight together, to determine when to sync DP grads
         self.backward_counter.clear()
         for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
+            logger.debug(
+                "_PipelineScheduleRuntime running time_step %d, action %s",
+                time_step,
+                action,
+            )
             try:
                 with record_function(_get_profiler_function_name(action)):
                     if action.computation_type in self._comp_type_to_function_map:
0 commit comments
