 }


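+# inference_mode() disables autograd tracking for these entry points, which
+# avoids gradient-bookkeeping overhead during generation and benchmarking.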
+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
@@ -61,56 +62,52 @@ def generate_image(pipe, prompt, image_name):
6162 print (f"Image generated saved as { image_name } " )


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype=torch.float16
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev",
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
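+    # (the helper also returns dynamic axes and shapes; only the inputs are
+    # needed for a latency benchmark)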
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
79+
80+ # Warmup
81+ print (f"Warming up: { num_warmup } iterations" )
82+ for _ in tqdm (range (num_warmup ), desc = "Warmup" ):
83+ _ = backbone (** dummy_inputs_dict )
84+
85+ # Benchmark
86+ torch .cuda .synchronize ()
7187 start_event = torch .cuda .Event (enable_timing = True )
7288 end_event = torch .cuda .Event (enable_timing = True )
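+    # The synchronize() above drains outstanding warmup work; the CUDA events
+    # then timestamp the GPU stream, so elapsed_time() reports device-side
+    # latency in milliseconds rather than host wall-clock time.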

-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=model_dtype):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
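+    # Nearest-rank percentiles taken by direct indexing into the sorted samples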
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
105+ print (f"\n Backbone-only inference latency:" )
106+ print (f" Average: { avg_latency :.2f} ms" )
107+ print (f" P50: { p50 :.2f} ms" )
108+ print (f" P95: { p95 :.2f} ms" )
109+ print (f" P99: { p99 :.2f} ms" )
+
     return avg_latency


@@ -196,7 +193,12 @@ def main():
     pipe.to("cuda")

     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+        )

     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)