Skip to content

Commit 3727ef8

Browse files
Updated on-gpu benchmaking with cuda event & sync
1 parent 426f81d commit 3727ef8

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
99

1010
## [Unreleased]
1111

12+
## [0.3.5] - 2023-08-08
13+
### Fixed
14+
- Updated on-gpu model benchmaking with best-practices on `cuda.Event` and `cuda.synchronize`.
15+
16+
1217
## [0.3.4] - 2022-02-22
1318

1419
### Fixed

pytorch_benchmark/benchmark.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,29 @@ def measure_repeated_inference_timing(
146146
):
147147
start_on_cpu = time()
148148
device_sample = transfer_to_device_fn(sample, model_device)
149-
start_on_device = time()
149+
150+
if model_device.type == "cuda":
151+
start_event = torch.cuda.Event(enable_timing=True)
152+
stop_event = torch.cuda.Event(enable_timing=True)
153+
start_event.record() # For GPU timing
154+
start_on_device = time() # For CPU timing
155+
150156
device_result = model(device_sample)
151-
stop_on_device = time()
157+
158+
if model_device.type == "cuda":
159+
stop_event.record()
160+
torch.cuda.synchronize()
161+
elapsed_on_device = stop_event.elapsed_time(start_event)
162+
stop_on_device = time()
163+
else:
164+
stop_on_device = time()
165+
elapsed_on_device = stop_on_device - start_on_device
166+
152167
transfer_to_device_fn(device_result, "cpu")
153168
stop_on_cpu = time()
154169

155170
t_c2d.append(start_on_device - start_on_cpu)
156-
t_inf.append(stop_on_device - start_on_device)
171+
t_inf.append(elapsed_on_device)
157172
t_d2c.append(stop_on_cpu - stop_on_device)
158173
t_tot.append(stop_on_cpu - start_on_cpu)
159174

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def from_file(file_name: str = "requirements.txt", comment_char: str = "#"):
2525

2626
setup(
2727
name="pytorch-benchmark",
28-
version="0.3.4",
28+
version="0.3.5",
2929
description="Easily benchmark PyTorch model FLOPs, latency, throughput, max allocated memory and energy consumption in one go.",
3030
long_description=long_description(),
3131
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)