11import gc
2+ from dataclasses import dataclass
3+ from typing import Any , Callable , Dict , Optional
24
35import torch
46import torch .utils .benchmark as benchmark
@@ -13,7 +15,7 @@ def benchmark_fn(f, *args, **kwargs):
1315 globals = {"args" : args , "kwargs" : kwargs , "f" : f },
1416 num_threads = 1 ,
1517 )
16- return f"{ (t0 .blocked_autorange ().mean ):.3f} "
18+ return float ( f"{ (t0 .blocked_autorange ().mean ):.3f} " )
1719
1820
1921def flush ():
@@ -23,11 +25,18 @@ def flush():
2325 torch .cuda .reset_peak_memory_stats ()
2426
2527
@dataclass
class BenchmarkScenario:
    """Description of a single benchmark case consumed by ``run_benchmark``.

    The runner builds the model with ``model_init_fn(**model_init_kwargs)``,
    obtains the forward-pass keyword arguments from ``get_model_input_dict()``
    and, when ``compile_kwargs`` is provided, calls ``model.compile`` with it
    for an additional compiled measurement.
    """

    # Label reported under the "scenario" key of the benchmark result.
    name: str
    # NOTE(review): annotated as ModelMixin although the field name suggests a
    # class object is expected; the visible runner code never reads it - confirm.
    model_cls: ModelMixin
    # Keyword arguments forwarded verbatim to ``model_init_fn``.
    model_init_kwargs: Dict[str, Any]
    # Factory returning the model to benchmark.
    model_init_fn: Callable
    # Zero-argument callable returning the kwargs for the model's forward call.
    get_model_input_dict: Callable[[], Dict[str, Any]]
    # Keyword arguments for ``model.compile``; None disables the compiled phase.
    compile_kwargs: Optional[Dict[str, Any]] = None
36+
37+
2638@require_torch_gpu
2739class BenchmarkMixin :
28- model_class : ModelMixin = None
29- compile_kwargs : dict = None
30-
def get_model_init_dict(self):
    """Hook with no default implementation.

    Raises:
        NotImplementedError: always; this base version provides no behavior.
    """
    raise NotImplementedError
3342
@@ -47,31 +56,61 @@ def post_benchmark(self, model):
4756 torch .compiler .reset ()
4857
@torch.no_grad()
def run_benchmark(self, scenario: "BenchmarkScenario"):
    """Benchmark one scenario uncompiled and, if requested, compiled.

    Args:
        scenario: Supplies the model factory (``model_init_fn`` and
            ``model_init_kwargs``), the input factory
            (``get_model_input_dict``) and optional ``compile_kwargs``.

    Returns:
        A dict with ``"scenario"``, ``"time_plain_s"`` and ``"mem_plain_GB"``;
        ``"time_compile_s"`` and ``"mem_compile_GB"`` are added when a
        compiled phase was run.
    """
    shared_args = {
        "init_fn": scenario.model_init_fn,
        "init_kwargs": scenario.model_init_kwargs,
        "get_input_fn": scenario.get_model_input_dict,
    }

    # 1) plain stats
    plain = self._run_phase(**shared_args, compile_kwargs=None)

    # 2) compiled stats. Only ``None`` disables this phase: an empty dict
    #    means "compile with default arguments" and must not be skipped.
    compiled = None
    if scenario.compile_kwargs is not None:
        compiled = self._run_phase(**shared_args, compile_kwargs=scenario.compile_kwargs)

    # 3) merge
    result = {
        "scenario": scenario.name,
        "time_plain_s": plain["time"],
        "mem_plain_GB": plain["memory"],
    }
    if compiled is not None:
        result.update(
            {
                "time_compile_s": compiled["time"],
                "mem_compile_GB": compiled["memory"],
            }
        )
    return result


def _run_phase(
    self,
    *,
    init_fn: Callable[..., Any],
    init_kwargs: Dict[str, Any],
    get_input_fn: Callable[[], Dict[str, torch.Tensor]],
    compile_kwargs: Optional[Dict[str, Any]],
) -> Dict[str, float]:
    """Run one measurement phase: set up, build, optionally compile, time, tear down.

    Args:
        init_fn: Factory called as ``init_fn(**init_kwargs)`` to build the model.
        init_kwargs: Keyword arguments for ``init_fn``.
        get_input_fn: Zero-argument callable returning the forward-call kwargs.
        compile_kwargs: Passed to ``model.compile``; ``None`` skips compilation,
            an empty dict compiles with default arguments.

    Returns:
        ``{"time": ..., "memory": ...}`` where ``time`` is the value returned by
        ``benchmark_fn`` and ``memory`` is the peak CUDA allocation reported by
        ``torch.cuda.max_memory_allocated()`` divided by 1024**3 and rounded to
        two decimals.
    """
    # setup
    self.pre_benchmark()

    model = init_fn(**init_kwargs)
    try:
        # (optional) compile
        if compile_kwargs is not None:
            model.compile(**compile_kwargs)

        # build inputs
        inputs = get_input_fn()

        # measure
        time_s = benchmark_fn(lambda m, d: m(**d), model, inputs)
        mem_gb = round(torch.cuda.max_memory_allocated() / (1024**3), 2)
    finally:
        # teardown runs even when compilation or measurement fails, so a
        # broken scenario does not leak its state into the next one.
        self.post_benchmark(model)
        del model

    return {"time": time_s, "memory": mem_gb}
0 commit comments