Skip to content

Commit 324db38

Browse files
Fix FLOPs measurement error on CUDA
1 parent f44a872 commit 324db38

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ From v1.0.0 and on, the project will adherence strictly to Semantic Versioning.
1212
## [0.3.5] - 2023-08-08
1313
### Fixed
1414
- Updated on-gpu model benchmaking with best-practices on `cuda.Event` and `cuda.synchronize`.
15+
- FLOPs measurement error on CUDA.
1516

1617

1718
## [0.3.4] - 2022-02-22

pytorch_benchmark/benchmark.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,11 @@ def benchmark(
343343
batch_size=1,
344344
)
345345

346-
flops = measure_flops(model, sample1, print_details)
346+
with torch.no_grad():
347+
flops = measure_flops(
348+
model, transfer_to_device_fn(sample1, model_device), print_details
349+
)
350+
347351
if _is_valid(flops):
348352
results["flops"] = flops
349353
print_fn(f"Model FLOPs: {flops} ({format_num(flops)})")

0 commit comments

Comments
 (0)