 import torch.utils.benchmark as benchmark

 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device


+logger = logging.get_logger(__name__)
+
+
 def benchmark_fn(f, *args, **kwargs):
     t0 = benchmark.Timer(
         stmt="f(*args, **kwargs)",
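The first hunk truncates `benchmark_fn` mid-call. For orientation, a plausible completion using the `torch.utils.benchmark.Timer` API is sketched below; the `globals` mapping and the reported statistic are assumptions, not part of this diff.

```python
import torch
import torch.utils.benchmark as benchmark

def benchmark_fn(f, *args, **kwargs):
    # Time f(*args, **kwargs); blocked_autorange() repeats the statement
    # until it has stable measurements and returns timing statistics.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
        num_threads=torch.get_num_threads(),
    )
    return f"{t0.blocked_autorange().mean:.3f}"  # assumed return format
```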
@@ -27,6 +31,8 @@ def flush():
     torch.cuda.reset_peak_memory_stats()


+# This default init function should suffice for most cases; users can
+# define their own when it doesn't.
 def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
     model = model_cls.from_pretrained(**init_kwargs).eval()
     if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
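A call to this helper might look like the sketch below. The model class, checkpoint, and offloading kwargs are illustrative assumptions; the `group_offload_kwargs` keys mirror `ModelMixin.enable_group_offload`, which this helper presumably forwards them to.

```python
import torch
from diffusers import AutoencoderKL  # illustrative model class

# Extra keyword arguments are forwarded to from_pretrained(); the offload
# kwargs are applied to the loaded model. All values here are hypothetical.
model = model_init_fn(
    AutoencoderKL,
    group_offload_kwargs={
        "onload_device": torch.device("cuda"),
        "offload_device": torch.device("cpu"),
        "offload_type": "block_level",
        "num_blocks_per_group": 1,
    },
    pretrained_model_name_or_path="stabilityai/sd-vae-ft-ema",
)
```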
@@ -64,24 +70,35 @@ def post_benchmark(self, model):
     @torch.no_grad()
     def run_benchmark(self, scenario: BenchmarkScenario):
         # 1) plain stats
-        plain = self._run_phase(
-            model_cls=scenario.model_cls,
-            init_fn=scenario.model_init_fn,
-            init_kwargs=scenario.model_init_kwargs,
-            get_input_fn=scenario.get_model_input_dict,
-            compile_kwargs=None,
-        )
-
-        # 2) compiled stats (if any)
-        compiled = {"time": None, "memory": None}
-        if scenario.compile_kwargs:
-            compiled = self._run_phase(
+        results = {}
+        plain = None
+        try:
+            plain = self._run_phase(
                 model_cls=scenario.model_cls,
                 init_fn=scenario.model_init_fn,
                 init_kwargs=scenario.model_init_kwargs,
                 get_input_fn=scenario.get_model_input_dict,
-                compile_kwargs=scenario.compile_kwargs,
+                compile_kwargs=None,
             )
+        except Exception as e:
+            logger.error(f"Benchmark could not be run with the following error:\n{e}")
+            return results
+
+        # 2) compiled stats (if any)
+        compiled = {"time": None, "memory": None}
+        if scenario.compile_kwargs:
+            try:
+                compiled = self._run_phase(
+                    model_cls=scenario.model_cls,
+                    init_fn=scenario.model_init_fn,
+                    init_kwargs=scenario.model_init_kwargs,
+                    get_input_fn=scenario.get_model_input_dict,
+                    compile_kwargs=scenario.compile_kwargs,
+                )
+            except Exception as e:
+                logger.error(f"Compilation benchmark could not be run with the following error:\n{e}")
+                if plain is None:
+                    return results

         # 3) merge
         result = {
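Net effect of this hunk: if the plain phase raises, `run_benchmark` logs the error and returns an empty dict; if only the compiled phase raises, the plain numbers are still merged and reported. A hypothetical caller, using only the `BenchmarkScenario` fields visible in this diff (the model class, inputs, and runner object are assumptions):

```python
import torch
from diffusers import AutoencoderKL  # illustrative model class

scenario = BenchmarkScenario(
    model_cls=AutoencoderKL,
    model_init_fn=model_init_fn,
    model_init_kwargs={"pretrained_model_name_or_path": "stabilityai/sd-vae-ft-ema"},
    get_model_input_dict=lambda: {"sample": torch.randn(1, 3, 64, 64, device="cuda")},
    compile_kwargs={"fullgraph": True},
)
record = runner.run_benchmark(scenario)  # `runner`: whatever class defines run_benchmark
# record == {} if the plain phase failed; otherwise it holds the merged stats.
```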
@@ -103,8 +120,9 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben
         if not isinstance(scenarios, list):
             scenarios = [scenarios]
         records = [self.run_benchmark(s) for s in scenarios]
-        df = pd.DataFrame.from_records(records)
+        df = pd.DataFrame.from_records([r for r in records if r])
         df.to_csv(filename, index=False)
+        logger.info(f"Results serialized to {filename=}.")

     def _run_phase(
         self,
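The list comprehension drops the empty dicts that failed benchmarks now return, so the CSV contains only completed runs and pandas never emits an all-NaN row for a failure. A minimal sketch of that behavior, with hypothetical record values and filename:

```python
import pandas as pd

# A failed scenario contributes {} and is filtered out before serialization.
records = [{"time": "1.234", "memory": "4.56"}, {}]
df = pd.DataFrame.from_records([r for r in records if r])
df.to_csv("collated_results.csv", index=False)
```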