@@ -60,30 +60,47 @@ def generate_image(pipe, prompt, image_name):
6060
6161
def benchmark_model(
    pipe, prompt, num_warmup=3, num_runs=10, num_inference_steps=10
):
    """Benchmark the backbone model inference time.

    Forward hooks are attached to the pipeline's backbone (``pipe.transformer``
    when present and not ``None``, otherwise ``pipe.unet``) so that only the
    backbone's forward calls are timed, not the rest of the pipeline.

    Args:
        pipe: Pipeline object; called as ``pipe(prompt, ...)`` and expected to
            expose the backbone as ``transformer`` or ``unet``.
        prompt: Prompt forwarded unchanged to ``pipe``.
        num_warmup: Number of untimed pipeline runs executed first.
        num_runs: Number of timed pipeline runs; must be at least 1.
        num_inference_steps: Value forwarded to ``pipe`` for every run.

    Returns:
        Total time spent in backbone forward calls during the timed runs,
        divided by ``num_runs``, in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    # Fail fast instead of running warm-ups and then dividing by zero.
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # A pipeline may define ``transformer`` but leave it unset; fall back to
    # ``unet`` in that case rather than failing on ``None``.
    backbone = getattr(pipe, "transformer", None)
    if backbone is None:
        backbone = pipe.unet

    forward_times_ms = []
    # Timing state lives in this closure so nothing is left on the module.
    started_at = 0.0

    def _start_timer(module, args):
        nonlocal started_at
        # Drain queued GPU work so it is not billed to this forward call.
        torch.cuda.synchronize()
        started_at = time.perf_counter()

    def _stop_timer(module, args, output):
        # Wait for the backbone's kernels to finish before reading the clock.
        torch.cuda.synchronize()
        forward_times_ms.append((time.perf_counter() - started_at) * 1000)

    def _run_pipeline():
        pipe(
            prompt,
            output_type="pil",
            num_inference_steps=num_inference_steps,
            generator=torch.Generator("cuda").manual_seed(42),
        )

    handles = [
        backbone.register_forward_pre_hook(_start_timer),
        backbone.register_forward_hook(_stop_timer),
    ]
    try:
        for _ in range(num_warmup):
            _run_pipeline()

        # Discard warm-up measurements; only the timed runs count.
        forward_times_ms.clear()

        for _ in range(num_runs):
            _run_pipeline()
    finally:
        # Always detach the hooks, even if the pipeline raised.
        for handle in handles:
            handle.remove()

    return sum(forward_times_ms) / num_runs
88105
89106
0 commit comments