75
75
]
76
76
77
77
EMBEDDING_DIM : int = 128
78
+ MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 100_000
78
79
79
80
80
81
class CompileMode (Enum ):
@@ -602,6 +603,7 @@ def _run_benchmark_core(
602
603
export_stacks : bool = False ,
603
604
reset_accumulated_memory_stats : bool = False ,
604
605
all_rank_traces : bool = False ,
606
+ memory_snapshot : bool = False ,
605
607
) -> BenchmarkResult :
606
608
"""Internal helper that contains the core benchmarking logic shared by
607
609
``benchmark`` and ``benchmark_func``. All heavy–lifting (timing, memory
@@ -736,6 +738,10 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
736
738
f"{ output_dir } /stacks-cuda-{ name } .stacks" , "self_cuda_time_total"
737
739
)
738
740
741
+ if memory_snapshot :
742
+ torch .cuda .memory ._record_memory_history (
743
+ max_entries = MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
744
+ )
739
745
with torch .profiler .profile (
740
746
activities = [
741
747
torch .profiler .ProfilerActivity .CPU ,
@@ -757,6 +763,17 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
757
763
else :
758
764
torch .cuda .synchronize (rank )
759
765
766
+ if memory_snapshot :
767
+ try :
768
+ torch .cuda .memory ._dump_snapshot (
769
+ f"{ output_dir } /memory-{ name } -rank{ rank } .pickle"
770
+ )
771
+ except Exception as e :
772
+ logger .error (f"Failed to capture memory snapshot { e } " )
773
+
774
+ # Stop recording memory snapshot history.
775
+ torch .cuda .memory ._record_memory_history (enabled = None )
776
+
760
777
return BenchmarkResult (
761
778
short_name = name ,
762
779
gpu_elapsed_time = gpu_elapsed_time ,
@@ -831,6 +848,7 @@ class BenchFuncConfig:
831
848
pre_gpu_load : int = 0
832
849
export_stacks : bool = False
833
850
all_rank_traces : bool = False
851
+ memory_snapshot : bool = False
834
852
835
853
# pyre-ignore [2]
836
854
def benchmark_func_kwargs (self , ** kwargs_to_override ) -> Dict [str , Any ]:
@@ -844,6 +862,7 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
844
862
"pre_gpu_load" : self .pre_gpu_load ,
845
863
"export_stacks" : self .export_stacks ,
846
864
"all_rank_traces" : self .all_rank_traces ,
865
+ "memory_snapshot" : self .memory_snapshot ,
847
866
} | kwargs_to_override
848
867
849
868
@@ -862,6 +881,7 @@ def benchmark_func(
862
881
pre_gpu_load : int = 0 ,
863
882
export_stacks : bool = False ,
864
883
all_rank_traces : bool = False ,
884
+ memory_snapshot : bool = False ,
865
885
) -> BenchmarkResult :
866
886
"""
867
887
Args:
@@ -870,7 +890,7 @@ def benchmark_func(
870
890
stats. ``rank == -1`` means single-process mode.
871
891
872
892
func_to_benchmark: Callable that executes one measured iteration.
873
- func_to_benchmark(batch_inputs, **kwargs)
893
+ func_to_benchmark(batch_inputs, **kwargs) -> None
874
894
bench_inputs, prof_inputs: List[Dict[str, Any]]. Each argument will be fed
875
895
to the function at once, and bench_inputs will be used for benchmarking
876
896
while prof_inputs will be used for profiling
@@ -885,6 +905,7 @@ def benchmark_func(
885
905
measured iteration (helps simulating a loaded allocator).
886
906
export_stacks: Whether to export flamegraph-compatible stack files.
887
907
all_rank_traces: Whether to export traces from all ranks.
908
+ memory_snapshot: Whether to capture a CUDA memory snapshot during profiling.
888
909
"""
889
910
if benchmark_func_kwargs is None :
890
911
benchmark_func_kwargs = {}
@@ -912,4 +933,5 @@ def _profile_iter_fn(prof: torch.profiler.profile) -> None:
912
933
export_stacks = export_stacks ,
913
934
reset_accumulated_memory_stats = True ,
914
935
all_rank_traces = all_rank_traces ,
936
+ memory_snapshot = memory_snapshot ,
915
937
)
0 commit comments