Skip to content

Commit 4d83a47

Browse files
committed
add flops and params.
1 parent a2c03a4 commit 4d83a47

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

benchmarks/benchmarking_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import torch
88
import torch.utils.benchmark as benchmark
9+
from torchprofile import profile_macs
910

1011
from diffusers.models.modeling_utils import ModelMixin
1112
from diffusers.utils import logging
@@ -31,6 +32,19 @@ def flush():
3132
torch.cuda.reset_peak_memory_stats()
3233

3334

35+
# Taken from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
36+
def calculate_flops(model, input_dict):
    """Estimate the FLOPs of one forward pass of ``model``.

    Args:
        model: Module to profile. It is switched to eval mode and left in that mode.
        input_dict: Mapping from ``model.forward`` argument names to example inputs.

    Returns:
        Twice the multiply-accumulate (MAC) count reported by ``profile_macs``.

    Raises:
        TypeError: If ``input_dict`` does not match the signature of ``model.forward``
            (unknown name, or a required argument is missing).
        ValueError: If an entry of ``input_dict`` can only reach ``model.forward`` as a
            keyword argument, which the profiler cannot trace.
    """
    import inspect

    # `profile_macs(model, args=(), kwargs=None, reduction=sum)` expects the model inputs as a
    # single positional tuple, and its tracer rejects keyword inputs. Splatting `**input_dict`
    # into it would be read as profiler options, so re-order the inputs to follow the forward
    # signature and pass them positionally instead.
    signature = inspect.signature(model.forward)
    bound = signature.bind(**input_dict)
    bound.apply_defaults()

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_inputs = []
    keyword_only_names = []
    for name, parameter in signature.parameters.items():
        if parameter.kind in positional_kinds:
            positional_inputs.append(bound.arguments[name])
        elif parameter.kind == inspect.Parameter.KEYWORD_ONLY and name in input_dict:
            keyword_only_names.append(name)
        elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
            keyword_only_names.extend(bound.arguments.get(name, {}))

    if keyword_only_names:
        # Dropping these silently would profile a different computation than requested.
        raise ValueError(
            f"Cannot profile keyword-only inputs {sorted(keyword_only_names)}: "
            "`profile_macs` only accepts positional model inputs."
        )

    model.eval()
    with torch.no_grad():
        macs = profile_macs(model, tuple(positional_inputs))
    flops = 2 * macs  # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
    return flops
42+
43+
44+
def calculate_params(model):
    """Return the total number of elements across all parameters of ``model``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
46+
47+
3448
# Users can define their own in case this doesn't suffice. For most cases,
3549
# it should be sufficient.
3650
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
@@ -69,6 +83,14 @@ def post_benchmark(self, model):
6983

7084
@torch.no_grad()
7185
def run_benchmark(self, scenario: BenchmarkScenario):
86+
# 0) Basic stats
87+
model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
88+
num_params = calculate_params(model)
89+
flops = calculate_flops(model, input_dict=scenario.model_init_kwargs)
90+
model.cpu()
91+
del model
92+
self.pre_benchmark()
93+
7294
# 1) plain stats
7395
results = {}
7496
plain = None
@@ -104,6 +126,8 @@ def run_benchmark(self, scenario: BenchmarkScenario):
104126
result = {
105127
"scenario": scenario.name,
106128
"model_cls": scenario.model_cls.__name__,
129+
"num_params": num_params,
130+
"flops": flops,
107131
"time_plain_s": plain["time"],
108132
"mem_plain_GB": plain["memory"],
109133
"time_compile_s": compiled["time"],

0 commit comments

Comments
 (0)