@@ -9,9 +9,13 @@
 import torch.utils.benchmark as benchmark

 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device


+logger = logging.get_logger(__name__)
+
+
 def benchmark_fn(f, *args, **kwargs):
     t0 = benchmark.Timer(
         stmt="f(*args, **kwargs)",
@@ -101,12 +105,16 @@ def post_benchmark(self, model):
     @torch.no_grad()
     def run_benchmark(self, scenario: BenchmarkScenario):
         # 0) Basic stats
-        print(f"Running scenario: {scenario.name}.")
-        model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
-        num_params = round(calculate_params(model) / 1e6, 2)
-        flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
-        model.cpu()
-        del model
+        logger.info(f"Running scenario: {scenario.name}.")
+        try:
+            model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
+            num_params = round(calculate_params(model) / 1e6, 2)
+            flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
+            model.cpu()
+            del model
+        except Exception as e:
+            logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
+            return {}
         self.pre_benchmark()

         # 1) plain stats
@@ -121,7 +129,7 @@ def run_benchmark(self, scenario: BenchmarkScenario):
                 compile_kwargs=None,
             )
         except Exception as e:
-            print(f"Benchmark could not be run with the following error\n: {e}")
+            logger.info(f"Benchmark could not be run with the following error:\n{e}")
             return results

         # 2) compiled stats (if any)
@@ -136,7 +144,7 @@ def run_benchmark(self, scenario: BenchmarkScenario):
                 compile_kwargs=scenario.compile_kwargs,
             )
         except Exception as e:
-            print(f"Compilation benchmark could not be run with the following error\n: {e}")
+            logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
             if plain is None:
                 return results
142150
@@ -166,10 +174,10 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben
166174 try :
167175 records .append (self .run_benchmark (s ))
168176 except Exception as e :
169- print (f"Running scenario ({ s .name } ) led to error:\n { e } " )
177+ logger . info (f"Running scenario ({ s .name } ) led to error:\n { e } " )
170178 df = pd .DataFrame .from_records ([r for r in records if r ])
171179 df .to_csv (filename , index = False )
172- print (f"Results serialized to { filename = } ." )
180+ logger . info (f"Results serialized to { filename = } ." )
173181
174182 def _run_phase (
175183 self ,
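
The first hunk cuts off `benchmark_fn` at its `stmt` argument. Below is a minimal sketch of how such a helper is typically completed with `torch.utils.benchmark.Timer`; the `globals` mapping, the `num_threads` choice, and the `blocked_autorange().mean` readout are assumptions consistent with the public `torch.utils.benchmark` API, not necessarily this file's exact body:

import torch
import torch.utils.benchmark as benchmark

def benchmark_fn(f, *args, **kwargs):
    # Timer runs f(*args, **kwargs) repeatedly; on CUDA it inserts the
    # synchronization needed so the measurement reflects real device time.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    # blocked_autorange() picks the number of runs automatically and
    # returns a Measurement whose .mean is in seconds.
    return f"{t0.blocked_autorange().mean:.3f}"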
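
Since the diff replaces every `print` with `logger.info`, note that the logger returned by `diffusers.utils.logging.get_logger` defaults to WARNING verbosity, so these messages stay silent unless the caller (or the DIFFUSERS_VERBOSITY environment variable) raises the level. A short usage sketch:

from diffusers.utils import logging

# Raise verbosity so logger.info(...) output from the benchmarking
# utilities actually reaches stderr (the default level is WARNING).
logging.set_verbosity_info()

logger = logging.get_logger(__name__)
logger.info("Now visible.")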