Skip to content

Commit 169f831

Browse files
committed
checking.
1 parent cc0a38a commit 169f831

File tree

1 file changed

+64
-25
lines changed

1 file changed

+64
-25
lines changed

benchmarks/benchmarking_utils.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import gc
2+
from dataclasses import dataclass
3+
from typing import Any, Callable, Dict, Optional
24

35
import torch
46
import torch.utils.benchmark as benchmark
@@ -13,7 +15,7 @@ def benchmark_fn(f, *args, **kwargs):
1315
globals={"args": args, "kwargs": kwargs, "f": f},
1416
num_threads=1,
1517
)
16-
return f"{(t0.blocked_autorange().mean):.3f}"
18+
return float(f"{(t0.blocked_autorange().mean):.3f}")
1719

1820

1921
def flush():
@@ -23,11 +25,18 @@ def flush():
2325
torch.cuda.reset_peak_memory_stats()
2426

2527

28+
@dataclass
29+
class BenchmarkScenario:
30+
name: str
31+
model_cls: ModelMixin
32+
model_init_kwargs: Dict[str, Any]
33+
model_init_fn: Callable
34+
get_model_input_dict: Callable[[], Dict[str, Any]]
35+
compile_kwargs: Optional[Dict[str, Any]] = None
36+
37+
2638
@require_torch_gpu
2739
class BenchmarkMixin:
28-
model_class: ModelMixin = None
29-
compile_kwargs: dict = None
30-
3140
def get_model_init_dict(self):
3241
raise NotImplementedError
3342

@@ -47,31 +56,61 @@ def post_benchmark(self, model):
4756
torch.compiler.reset()
4857

4958
@torch.no_grad()
50-
def run_benchmark(self):
59+
def run_benchmark(self, scenario: BenchmarkScenario):
60+
# 1) plain stats
61+
plain = self._run_phase(
62+
init_fn=scenario.model_init_fn,
63+
init_kwargs=scenario.model_init_kwargs,
64+
get_input_fn=scenario.get_model_input_dict,
65+
compile_kwargs=None,
66+
)
67+
68+
# 2) compiled stats (if any)
69+
compiled = None
70+
if scenario.compile_kwargs:
71+
compiled = self._run_phase(
72+
init_fn=scenario.model_init_fn,
73+
init_kwargs=scenario.model_init_kwargs,
74+
get_input_fn=scenario.get_model_input_dict,
75+
compile_kwargs=scenario.compile_kwargs,
76+
)
77+
78+
# 3) merge
79+
result = {"scenario": scenario.name, "time_plain_s": plain["time"], "mem_plain_GB": plain["memory"]}
80+
if compiled:
81+
result.update(
82+
{
83+
"time_compile_s": compiled["time"],
84+
"mem_compile_GB": compiled["memory"],
85+
}
86+
)
87+
return result
88+
89+
def _run_phase(
90+
self,
91+
*,
92+
init_fn: Callable[..., Any],
93+
init_kwargs: Dict[str, Any],
94+
get_input_fn: Callable[[], Dict[str, torch.Tensor]],
95+
compile_kwargs: Optional[Dict[str, Any]],
96+
) -> Dict[str, float]:
97+
# setup
5198
self.pre_benchmark()
5299

53-
model = self.initialize_model() # Takes care of device placement.
54-
input_dict = self.get_input_dict() # Takes care of device placement.
55-
56-
time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict)
57-
memory = torch.cuda.max_memory_allocated() / (1024**3)
58-
memory = float(f"{memory:.2f}")
59-
non_compile_stats = {"time": time, "memory": memory}
100+
# init & (optional) compile
101+
model = init_fn(**init_kwargs)
102+
if compile_kwargs:
103+
model.compile(**compile_kwargs)
60104

61-
self.post_benchmark(model)
62-
del model
63-
self.pre_benchmark()
105+
# build inputs
106+
inp = get_input_fn()
64107

65-
compile_stats = None
66-
if self.compile_kwargs is not None:
67-
model = self.initialize_model()
68-
input_dict = self.get_input_dict()
69-
model.compile(**self.compile_kwargs)
70-
time = benchmark_fn(lambda model, input_dict: model(**input_dict), model, input_dict)
71-
memory = torch.cuda.max_memory_allocated() / (1024**3)
72-
memory = float(f"{memory:.2f}")
73-
compile_stats = {"time": time, "memory": memory}
108+
# measure
109+
time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
110+
mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
111+
mem_gb = round(mem_gb, 2)
74112

113+
# teardown
75114
self.post_benchmark(model)
76115
del model
77-
return non_compile_stats, compile_stats
116+
return {"time": time_s, "memory": mem_gb}

0 commit comments

Comments
 (0)