Skip to content

Commit 36afdea

Browse files
committed
error handling and logging.
1 parent 31e34d5 commit 36afdea

File tree

1 file changed

+32
-14
lines changed

1 file changed

+32
-14
lines changed

benchmarks/benchmarking_utils.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,13 @@
88
import torch.utils.benchmark as benchmark
99

1010
from diffusers.models.modeling_utils import ModelMixin
11+
from diffusers.utils import logging
1112
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
1213

1314

15+
logger = logging.get_logger(__name__)
16+
17+
1418
def benchmark_fn(f, *args, **kwargs):
1519
t0 = benchmark.Timer(
1620
stmt="f(*args, **kwargs)",
@@ -27,6 +31,8 @@ def flush():
2731
torch.cuda.reset_peak_memory_stats()
2832

2933

34+
# Users can define their own in case this doesn't suffice. For most cases,
35+
# it should be sufficient.
3036
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
3137
model = model_cls.from_pretrained(**init_kwargs).eval()
3238
if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
@@ -64,24 +70,35 @@ def post_benchmark(self, model):
6470
@torch.no_grad()
6571
def run_benchmark(self, scenario: BenchmarkScenario):
6672
# 1) plain stats
67-
plain = self._run_phase(
68-
model_cls=scenario.model_cls,
69-
init_fn=scenario.model_init_fn,
70-
init_kwargs=scenario.model_init_kwargs,
71-
get_input_fn=scenario.get_model_input_dict,
72-
compile_kwargs=None,
73-
)
74-
75-
# 2) compiled stats (if any)
76-
compiled = {"time": None, "memory": None}
77-
if scenario.compile_kwargs:
78-
compiled = self._run_phase(
73+
results = {}
74+
plain = None
75+
try:
76+
plain = self._run_phase(
7977
model_cls=scenario.model_cls,
8078
init_fn=scenario.model_init_fn,
8179
init_kwargs=scenario.model_init_kwargs,
8280
get_input_fn=scenario.get_model_input_dict,
83-
compile_kwargs=scenario.compile_kwargs,
81+
compile_kwargs=None,
8482
)
83+
except Exception as e:
84+
logger.error(f"Benchmark could not be run with the following error\n: {e}")
85+
return results
86+
87+
# 2) compiled stats (if any)
88+
compiled = {"time": None, "memory": None}
89+
if scenario.compile_kwargs:
90+
try:
91+
compiled = self._run_phase(
92+
model_cls=scenario.model_cls,
93+
init_fn=scenario.model_init_fn,
94+
init_kwargs=scenario.model_init_kwargs,
95+
get_input_fn=scenario.get_model_input_dict,
96+
compile_kwargs=scenario.compile_kwargs,
97+
)
98+
except Exception as e:
99+
logger.error(f"Compilation benchmark could not be run with the following error\n: {e}")
100+
if plain is None:
101+
return results
85102

86103
# 3) merge
87104
result = {
@@ -103,8 +120,9 @@ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[Ben
103120
if not isinstance(scenarios, list):
104121
scenarios = [scenarios]
105122
records = [self.run_benchmark(s) for s in scenarios]
106-
df = pd.DataFrame.from_records(records)
123+
df = pd.DataFrame.from_records([r for r in records if r])
107124
df.to_csv(filename, index=False)
125+
logger.info(f"Results serialized to {filename=}.")
108126

109127
def _run_phase(
110128
self,

0 commit comments

Comments
 (0)