66import pandas as pd
77import torch
88import torch .utils .benchmark as benchmark
9+ from torchprofile import profile_macs
910
1011from diffusers .models .modeling_utils import ModelMixin
1112from diffusers .utils import logging
@@ -31,6 +32,19 @@ def flush():
3132 torch .cuda .reset_peak_memory_stats ()
3233
3334
# Taken from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
def calculate_flops(model, input_dict):
    """Estimate the FLOPs of one forward pass of ``model``.

    Args:
        model: Module to profile. It is switched to eval mode and is left in
            eval mode when this function returns.
        input_dict: Mapping from ``model.forward`` argument names to example
            inputs.

    Returns:
        Twice the multiply-accumulate count reported by ``profile_macs``.

    Raises:
        TypeError: If ``input_dict`` contains a name ``model.forward`` does not
            accept, or omits a positional argument that has no default.
    """
    # Function-scope import so this block does not depend on the file header.
    import inspect

    # `profile_macs(model, args=(), kwargs=None, ...)` expects the example
    # inputs as ONE positional tuple; unpacking `input_dict` into the call
    # would be read as (unknown) keyword arguments of `profile_macs` itself.
    # Re-order the named inputs to match the forward() signature instead.
    forward_signature = inspect.signature(model.forward)
    bound = forward_signature.bind_partial(**input_dict)
    bound.apply_defaults()

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # NOTE(review): keyword-only forward() parameters cannot be forwarded
    # positionally, so values supplied for them are not passed to the profiler.
    positional_names = [
        name
        for name, param in forward_signature.parameters.items()
        if param.kind in positional_kinds
    ]

    missing = [name for name in positional_names if name not in bound.arguments]
    if missing:
        raise TypeError(
            f"`input_dict` is missing required forward() arguments: {missing}"
        )
    args = tuple(bound.arguments[name] for name in positional_names)

    model.eval()
    with torch.no_grad():
        macs = profile_macs(model, args)
    # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
    return 2 * macs
42+
43+
def calculate_params(model):
    """Return the total element count over every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``, an iterable of items that
            each provide ``numel()``.

    Returns:
        The sum of ``numel()`` over all parameters; ``0`` when there are none.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
46+
47+
3448# Users can define their own in case this doesn't suffice. For most cases,
3549# it should be sufficient.
3650def model_init_fn (model_cls , group_offload_kwargs = None , layerwise_upcasting = False , ** init_kwargs ):
@@ -69,6 +83,14 @@ def post_benchmark(self, model):
6983
7084 @torch .no_grad ()
7185 def run_benchmark (self , scenario : BenchmarkScenario ):
86+ # 0) Basic stats
87+ model = model_init_fn (scenario .model_cls , ** scenario .model_init_kwargs )
88+ num_params = calculate_params (model )
89+ flops = calculate_flops (model , input_dict = scenario .model_init_kwargs )
90+ model .cpu ()
91+ del model
92+ self .pre_benchmark ()
93+
7294 # 1) plain stats
7395 results = {}
7496 plain = None
@@ -104,6 +126,8 @@ def run_benchmark(self, scenario: BenchmarkScenario):
104126 result = {
105127 "scenario" : scenario .name ,
106128 "model_cls" : scenario .model_cls .__name__ ,
129+ "num_params" : num_params ,
130+ "flops" : flops ,
107131 "time_plain_s" : plain ["time" ],
108132 "mem_plain_GB" : plain ["memory" ],
109133 "time_compile_s" : compiled ["time" ],
0 commit comments