11import resource
2+ import shutil
23import sys
34import typing as t
45from contextlib import contextmanager
89class Profiler :
910 def __init__ (self , gpu : bool = False ):
1011 self ._errors : dict [str , str ] = {}
12+ self ._disk : dict [str , int ] = {"start" : get_disk_usage ()}
1113 self ._ram : dict [str , int ] = {"start" : get_peak_rss ()}
1214 self ._gpu : dict [str , list [dict [str , t .Any ]]] = {"start" : get_gpu_usage ()} if gpu else {}
1315 self ._imports_at_start = get_current_imports ()
@@ -18,6 +20,9 @@ def track_memory(self, event: str) -> None:
1820 if self ._gpu :
1921 self ._gpu [event ] = get_gpu_usage ()
2022
23+ def track_disk (self , event : str ) -> None :
24+ self ._disk [event ] = get_disk_usage ()
25+
2126 def track_error (self , event : str , error : str ) -> None :
2227 self ._errors [event ] = error
2328
@@ -30,6 +35,7 @@ def as_dict(self) -> dict[str, t.Any]:
3035
3136 as_dict : dict [str , t .Any ] = {
3237 "ram" : self ._ram ,
38+ "disk" : self ._disk ,
3339 "errors" : self ._errors ,
3440 "extra" : {"imports" : imported },
3541 } | self ._additionals
@@ -63,6 +69,14 @@ def capture_output() -> t.Generator[tuple[StringIO, StringIO], None, None]:
6369 sys .stderr = old_stderr
6470
6571
72+ def get_disk_usage () -> int :
73+ """
74+ Get the disk usage.
75+ """
76+ _ , used , _ = shutil .disk_usage ("/" )
77+ return used
78+
79+
6680def get_peak_rss () -> int :
6781 """
6882 Get the peak RSS memory usage of the current process.
0 commit comments