Skip to content

Commit 3605eff

Browse files
Fix memory measurement for bs=1
1 parent 1bdb608 commit 3605eff

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_benchmark/benchmark.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def benchmark(
344344
# Measure Allocated Memory
345345
if model_device.type == "cuda":
346346
pre_mem, post_mem, max_mem = measure_allocated_memory(
347-
model, sample, model_device, transfer_to_device_fn, print_details
347+
model, s, model_device, transfer_to_device_fn, print_details
348348
)
349349
memory[f"batch_size_{bs}"] = {
350350
"pre_inference_bytes": pre_mem,
@@ -386,7 +386,7 @@ def benchmark(
386386
model_device,
387387
transfer_to_device_fn,
388388
num_runs,
389-
batch_size,
389+
bs,
390390
)
391391
print_fn(
392392
fmt({f"Timing results (batch_size={bs})": timing[f"batch_size_{bs}"]})

0 commit comments

Comments
 (0)