
Commit d9ef101

H-Huang authored and pytorchmergebot committed
[PP] Optimize memory usage by releasing output memory earlier (pytorch#153383)
Since `output_chunks` is only used by the last stage, we should not keep every stage's outputs in memory; this allows the output memory to be freed earlier.

Pull Request resolved: pytorch#153383
Approved by: https://github.com/Skylion007, https://github.com/kwen2501
1 parent f1de3f9 commit d9ef101
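
To illustrate the idea outside the diff, here is a minimal, self-contained sketch (not the actual torch.distributed.pipelining implementation; the ToyStage class and its members are hypothetical): a stage already keeps one reference to each microbatch's output in fwd_cache for sending and backward, so only the last stage needs to accumulate outputs for the final merge.

# Hypothetical sketch of the optimization (not the real PipelineStage API).
class ToyStage:
    def __init__(self, module, is_last: bool):
        self.module = module
        self.is_last = is_last
        self.fwd_cache = {}      # chunk_id -> (output_tuple, saved_inputs)
        self.output_chunks = []  # only populated on the last stage

    def forward_one_chunk(self, chunk_id, *args):
        output = self.module(*args)
        output_tuple = output if isinstance(output, tuple) else (output,)
        # Only the last stage keeps outputs around for the final merge;
        # on other stages the output can be freed once it has been sent
        # and its backward has run.
        if self.is_last:
            self.output_chunks.append(output)
        self.fwd_cache[chunk_id] = (output_tuple, args)
        return output

    def get_fwd_send_ops(self, chunk_id):
        # Read the tensors to send from fwd_cache instead of output_chunks,
        # so intermediate stages hold no extra reference to their outputs.
        output_tuple, _ = self.fwd_cache[chunk_id]
        return list(output_tuple)

The only behavioral requirement is that whatever the send path needs is still reachable through fwd_cache, which forward_one_chunk populates in the same call.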

File tree

2 files changed: 62 additions, 5 deletions

test/distributed/pipelining/test_stage.py

Lines changed: 58 additions & 0 deletions
@@ -320,6 +320,64 @@ def test_custom_dw_errors(self):
         with self.assertRaisesRegex(AssertionError, "backward_one_chunk"):
             stage_with_dw_builder.backward_weight_one_chunk(bwd_chunk_id=0)
 
+    @requires_nccl()
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
+    def test_output_chunks_memory_usage(self):
+        """Test that output_chunks doesn't store memory for non-last stages."""
+        full_mod = MultiMLP(d_hid, n_layers=self.world_size)
+        full_mod.to(self.device)
+        stage_mod = full_mod.get_submodule(f"layers.{self.rank}")
+        x = torch.randn(batch_size, d_hid, device=self.device)
+        target = torch.randn(batch_size, d_hid, device=self.device)
+        stage = PipelineStage(
+            stage_mod,
+            self.rank,
+            self.world_size,
+            self.device,
+        )
+        self.assertEqual(
+            len(stage.output_chunks), 0, "output_chunks should be empty initially"
+        )
+
+        schedule = ScheduleGPipe(
+            stage, chunks, loss_fn=torch.nn.MSELoss(reduction="sum")
+        )
+
+        def _run_step(x):
+            if self.rank == 0:
+                return schedule.step(x)
+            elif self.rank == self.world_size - 1:
+                return schedule.step(target=target)
+            else:
+                return schedule.step()
+
+        _run_step(x)
+
+        # Verify fwd_cache is empty
+        self.assertEqual(len(stage.fwd_cache), 0, "fwd_cache should be cleared")
+
+        # Check output_chunks state after step
+        if self.rank == self.world_size - 1:
+            self.assertEqual(
+                len(stage.output_chunks),
+                chunks,
+                "Last stage should store output chunks",
+            )
+        else:
+            self.assertEqual(
+                len(stage.output_chunks),
+                0,
+                f"Non-last stage (rank {self.rank}) should not store output chunks",
+            )
+
+        # Clear the schedule and stage caches
+        stage.clear_runtime_states()
+        if self.rank == self.world_size - 1:
+            # The last stage's output_chunks should now be cleared as well
+            self.assertEqual(
+                len(stage.output_chunks), 0, "output_chunks should be empty after clearing"
+            )
+
 
 instantiate_parametrized_tests(StageTest)
 
torch/distributed/pipelining/stage.py

Lines changed: 4 additions & 5 deletions
@@ -433,10 +433,7 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]:
         """
         Get the activation send ops for current stage's forward.
         """
-        output = self.output_chunks[fwd_chunk_id]
-        # Unify output form to tuple for easy correspondance with
-        # `act_send_info`
-        output_tuple = output if type(output) is tuple else (output,)
+        output_tuple, _ = self.fwd_cache[fwd_chunk_id]
 
         ops: list[dist.P2POp] = []
 
@@ -719,7 +716,9 @@ def forward_one_chunk(
         output_tuple = _normalize_model_output_as_tuple(output)
 
         # Prepare for final output merge or reduction
-        self.output_chunks.append(output)
+        # Output chunks is only used for the last stage since we only merge the output of the last stage
+        if self.is_last:
+            self.output_chunks.append(output)
 
         # Save activations and inputs for backward
         flat_args = flatten_args(composite_args)
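
One way to observe the effect on an intermediate rank is to compare peak memory before and after this change; the following is a hedged usage sketch that reuses the `schedule` and `stage` objects from the test above and assumes a CUDA device:

import torch

torch.cuda.reset_peak_memory_stats()
schedule.step()  # rank 0 would pass the input and the last rank the target, as in _run_step above
print(f"peak allocated on this rank: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
print(f"output chunks retained on this rank: {len(stage.output_chunks)}")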
