Skip to content

Commit fdb46f1

Browse files
TroyGarden authored and meta-codesync[bot] committed
add memory snapshot support in benchmark_core (#3437)
Summary: Pull Request resolved: #3437 # context * add memory snapshot support in the torchrec benchmark core NOTE: memory snapshot runs with the profiler (enabled by providing a valid `profile_dir`), so the trace file is also available * instructions: https://pytorch.org/blog/understanding-gpu-memory-1/ * visualization tool: https://docs.pytorch.org/memory_viz * example snapshot rank-0 {F1982513266} rank-1 {F1982513396} Reviewed By: spmex Differential Revision: D83991566 fbshipit-source-id: 043720950f5a52e23bb1be43634d8667f893eeb8
1 parent d991b46 commit fdb46f1

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

torchrec/distributed/benchmark/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ python -m torchrec.distributed.benchmark.benchmark_train_pipeline \
1717
- internal:
1818
```
1919
buck2 run @fbcode//mode/opt fbcode//torchrec/distributed/benchmark:benchmark_comms -- \
20-
a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10)
20+
a2a_single --name=a2a_sync_base-$(hg whereami | cut -c 1-10) --memory_snapshot=true
2121
```
2222
- oss:
2323
```
2424
python -m torchrec.distributed.benchmark.benchmark_comms \
25-
a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER)
25+
a2a_single --name=a2a_sync_base-$(git rev-parse --short HEAD || echo $USER) --memory_snapshot=true
2626
```

torchrec/distributed/benchmark/base.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
]
7676

7777
EMBEDDING_DIM: int = 128
78+
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 100_000
7879

7980

8081
class CompileMode(Enum):
@@ -602,6 +603,7 @@ def _run_benchmark_core(
602603
export_stacks: bool = False,
603604
reset_accumulated_memory_stats: bool = False,
604605
all_rank_traces: bool = False,
606+
memory_snapshot: bool = False,
605607
) -> BenchmarkResult:
606608
"""Internal helper that contains the core benchmarking logic shared by
607609
``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -736,6 +738,10 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
736738
f"{output_dir}/stacks-cuda-{name}.stacks", "self_cuda_time_total"
737739
)
738740

741+
if memory_snapshot:
742+
torch.cuda.memory._record_memory_history(
743+
max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
744+
)
739745
with torch.profiler.profile(
740746
activities=[
741747
torch.profiler.ProfilerActivity.CPU,
@@ -757,6 +763,17 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
757763
else:
758764
torch.cuda.synchronize(rank)
759765

766+
if memory_snapshot:
767+
try:
768+
torch.cuda.memory._dump_snapshot(
769+
f"{output_dir}/memory-{name}-rank{rank}.pickle"
770+
)
771+
except Exception as e:
772+
logger.error(f"Failed to capture memory snapshot {e}")
773+
774+
# Stop recording memory snapshot history.
775+
torch.cuda.memory._record_memory_history(enabled=None)
776+
760777
return BenchmarkResult(
761778
short_name=name,
762779
gpu_elapsed_time=gpu_elapsed_time,
@@ -831,6 +848,7 @@ class BenchFuncConfig:
831848
pre_gpu_load: int = 0
832849
export_stacks: bool = False
833850
all_rank_traces: bool = False
851+
memory_snapshot: bool = False
834852

835853
# pyre-ignore [2]
836854
def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -844,6 +862,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
844862
"pre_gpu_load": self.pre_gpu_load,
845863
"export_stacks": self.export_stacks,
846864
"all_rank_traces": self.all_rank_traces,
865+
"memory_snapshot": self.memory_snapshot,
847866
} | kwargs_to_override
848867

849868

@@ -862,6 +881,7 @@ def benchmark_func(
862881
pre_gpu_load: int = 0,
863882
export_stacks: bool = False,
864883
all_rank_traces: bool = False,
884+
memory_snapshot: bool = False,
865885
) -> BenchmarkResult:
866886
"""
867887
Args:
@@ -870,7 +890,7 @@ def benchmark_func(
870890
stats. ``rank == -1`` means single-process mode.
871891
872892
func_to_benchmark: Callable that executes one measured iteration.
873-
func_to_benchmark(batch_inputs, **kwargs)
893+
func_to_benchmark(batch_inputs, **kwargs) -> None
874894
bench_inputs, prof_inputs: List[Dict[str, Any]] this argument will be fed
875895
to the function at once, and bench_inputs will be used for benchmarking
876896
while prof_inputs will be used for profiling
@@ -885,6 +905,7 @@ def benchmark_func(
885905
measured iteration (helps simulating a loaded allocator).
886906
export_stacks: Whether to export flamegraph-compatible stack files.
887907
all_rank_traces: Whether to export traces from all ranks.
908+
memory_snapshot: Whether to capture memory snapshot during the profiling
888909
"""
889910
if benchmark_func_kwargs is None:
890911
benchmark_func_kwargs = {}
@@ -912,4 +933,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
912933
export_stacks=export_stacks,
913934
reset_accumulated_memory_stats=True,
914935
all_rank_traces=all_rank_traces,
936+
memory_snapshot=memory_snapshot,
915937
)

0 commit comments

Comments (0)