Commit ee0fcd4

quality improvements.
1 parent 61dd029 commit ee0fcd4

2 files changed: +26 additions, -18 deletions


benchmarks/benchmarking_utils.py

Lines changed: 18 additions & 10 deletions
@@ -9,9 +9,13 @@
 import torch.utils.benchmark as benchmark
 
 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import require_torch_gpu, torch_device
 
 
+logger = logging.get_logger(__name__)
+
+
 def benchmark_fn(f, *args, **kwargs):
     t0 = benchmark.Timer(
         stmt="f(*args, **kwargs)",
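
Context, not part of the commit: diffusers.utils.logging wraps Python's standard logging, so the new logger.info calls only show up when the library verbosity allows it. A minimal sketch of surfacing them from a benchmark script, assuming set_verbosity_info (the standard diffusers verbosity helper) is the mechanism used:

from diffusers.utils import logging

# Raise diffusers' logging verbosity so INFO-level messages are emitted.
logging.set_verbosity_info()

logger = logging.get_logger(__name__)
logger.info("Running scenario: example-scenario.")  # illustrative message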
@@ -101,12 +105,16 @@ def post_benchmark(self, model):
     @torch.no_grad()
     def run_benchmark(self, scenario: BenchmarkScenario):
         # 0) Basic stats
-        print(f"Running scenario: {scenario.name}.")
-        model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
-        num_params = round(calculate_params(model) / 1e6, 2)
-        flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
-        model.cpu()
-        del model
+        logger.info(f"Running scenario: {scenario.name}.")
+        try:
+            model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
+            num_params = round(calculate_params(model) / 1e6, 2)
+            flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e6, 2)
+            model.cpu()
+            del model
+        except Exception as e:
+            logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
+            return {}
         self.pre_benchmark()
 
         # 1) plain stats
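
The empty dict returned on failure pairs with the collation step later in this file, which keeps only truthy records before building the DataFrame. A small illustrative sketch of that interplay; the record keys and values here are hypothetical:

import pandas as pd

# Hypothetical per-scenario results; a failed scenario contributes an empty dict
# because run_benchmark() returns {} when model init or FLOPs calculation fails.
records = [
    {"scenario": "model-default", "time": 1.23},
    {},
    {"scenario": "model-compile", "time": 0.98},
]

# Same filter as in run_bencmarks_and_collate: falsy (empty) records are dropped.
df = pd.DataFrame.from_records([r for r in records if r])
print(df)  # only the two successful scenarios remain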
@@ -121,7 +129,7 @@ def run_benchmark(self, scenario: BenchmarkScenario):
                 compile_kwargs=None,
             )
         except Exception as e:
-            print(f"Benchmark could not be run with the following error\n: {e}")
+            logger.info(f"Benchmark could not be run with the following error:\n{e}")
             return results
 
         # 2) compiled stats (if any)
@@ -136,7 +144,7 @@ def run_benchmark(self, scenario: BenchmarkScenario):
                 compile_kwargs=scenario.compile_kwargs,
             )
         except Exception as e:
-            print(f"Compilation benchmark could not be run with the following error\n: {e}")
+            logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
             if plain is None:
                 return results
 
@@ -166,10 +174,10 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben
             try:
                 records.append(self.run_benchmark(s))
             except Exception as e:
-                print(f"Running scenario ({s.name}) led to error:\n{e}")
+                logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
         df = pd.DataFrame.from_records([r for r in records if r])
         df.to_csv(filename, index=False)
-        print(f"Results serialized to {filename=}.")
+        logger.info(f"Results serialized to {filename=}.")
 
     def _run_phase(
         self,

benchmarks/run_all.py

Lines changed: 8 additions & 8 deletions
@@ -34,16 +34,16 @@ def run_command(command: list[str], return_stdout=False):
 
 def run_scripts():
     python_files = sorted(glob.glob(PATTERN))
+    python_files = [f for f in python_files if f != "benchmarking_utils.py"]
 
     for file in python_files:
-        if file != "benchmarking_utils.py":
-            print(f"****** Running file: {file} ******")
-            command = f"python {file}"
-            try:
-                run_command(command.split())
-            except SubprocessCallException as e:
-                print(f"Error running {file}: {e}")
-                continue
+        print(f"****** Running file: {file} ******")
+        command = f"python {file}"
+        try:
+            run_command(command.split())
+        except SubprocessCallException as e:
+            print(f"Error running {file}:\n{e}")
+            continue
 
 
 def merge_csvs():
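
One caveat worth noting, not addressed by the commit: the new filter compares each glob result against the bare filename, so it only excludes the utils module as long as PATTERN yields plain filenames. A hedged sketch of a basename-based variant that would also handle a directory prefix; the PATTERN value shown is illustrative only, the real constant is defined elsewhere in run_all.py:

import glob
import os

PATTERN = "benchmarking_*.py"  # illustrative placeholder

python_files = sorted(glob.glob(PATTERN))
# Compare basenames so the exclusion still applies if the pattern includes a directory prefix.
python_files = [f for f in python_files if os.path.basename(f) != "benchmarking_utils.py"]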
