 }


+@torch.inference_mode()
 def generate_image(pipe, prompt, image_name):
     seed = 42
     image = pipe(
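The `@torch.inference_mode()` decorator added above disables autograd tracking for everything executed inside the call, removing gradient bookkeeping overhead during generation and benchmarking. A minimal sketch of the effect (illustrative, not part of this commit):

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.inference_mode():
    y = x * 2  # executed without autograd recording
print(y.requires_grad)        # False: no gradient history is kept
print(torch.is_inference(y))  # True: y was created in inference mode
```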
@@ -59,56 +60,52 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated saved as {image_name}")


-def benchmark_model(
-    pipe, prompt, num_warmup=10, num_runs=50, num_inference_steps=20, model_dtype="Half"
+@torch.inference_mode()
+def benchmark_backbone_standalone(
+    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev", model_dtype="Half"
 ):
-    """Benchmark the backbone model inference time."""
+    """Benchmark the backbone model directly without running the full pipeline."""
     backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet

-    backbone_times = []
+    # Generate dummy inputs for the backbone
+    dummy_inputs, _, _ = generate_dummy_inputs_and_dynamic_axes_and_shapes(model_name, backbone)
+
+    # Extract the dict from the tuple and move to cuda
+    dummy_inputs_dict = {
+        k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in dummy_inputs[0].items()
+    }
+
+    # Warmup
+    print(f"Warming up: {num_warmup} iterations")
+    for _ in tqdm(range(num_warmup), desc="Warmup"):
+        _ = backbone(**dummy_inputs_dict)
+
+    # Benchmark
+    torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)

-    def forward_pre_hook(_module, _input):
+    print(f"Benchmarking: {num_benchmark} iterations")
+    times = []
+    for _ in tqdm(range(num_benchmark), desc="Benchmark"):
         start_event.record()
-
-    def forward_hook(_module, _input, _output):
+        _ = backbone(**dummy_inputs_dict)
         end_event.record()
         torch.cuda.synchronize()
-        backbone_times.append(start_event.elapsed_time(end_event))
-
-    pre_handle = backbone.register_forward_pre_hook(forward_pre_hook)
-    post_handle = backbone.register_forward_hook(forward_hook)
-
-    try:
-        print(f"Starting warmup: {num_warmup} runs")
-        for _ in tqdm(range(num_warmup), desc="Warmup"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-
-        backbone_times.clear()
-
-        print(f"Starting benchmark: {num_runs} runs")
-        for _ in tqdm(range(num_runs), desc="Benchmark"):
-            with torch.amp.autocast("cuda", dtype=dtype_map[model_dtype]):
-                _ = pipe(
-                    prompt,
-                    output_type="pil",
-                    num_inference_steps=num_inference_steps,
-                    generator=torch.Generator("cuda").manual_seed(42),
-                )
-    finally:
-        pre_handle.remove()
-        post_handle.remove()
-
-    total_backbone_time = sum(backbone_times)
-    avg_latency = total_backbone_time / (num_runs * num_inference_steps)
-    print(f"Inference latency of the torch backbone: {avg_latency:.2f} ms")
+        times.append(start_event.elapsed_time(end_event))
+
+    avg_latency = sum(times) / len(times)
+    times = sorted(times)
+    p50 = times[len(times) // 2]
+    p95 = times[int(len(times) * 0.95)]
+    p99 = times[int(len(times) * 0.99)]
+
+    print(f"\nBackbone-only inference latency ({model_dtype}):")
+    print(f"  Average: {avg_latency:.2f} ms")
+    print(f"  P50: {p50:.2f} ms")
+    print(f"  P95: {p95:.2f} ms")
+    print(f"  P99: {p99:.2f} ms")
+
     return avg_latency


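The rewritten function times the backbone with a reusable pair of CUDA events rather than host-side clocks, so CPU queueing overhead is excluded from the measurement. The same pattern works for any GPU workload; a minimal sketch, where `work` is a stand-in for the backbone forward pass (hypothetical helper, not from this commit):

```python
import torch

@torch.inference_mode()
def time_cuda_ms(work, num_warmup=10, num_benchmark=100):
    """Return per-iteration GPU latencies in milliseconds for a zero-arg callable."""
    for _ in range(num_warmup):   # warm up kernels, caches, and the allocator
        work()
    torch.cuda.synchronize()      # drain pending work so timing starts clean
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for _ in range(num_benchmark):
        start.record()
        work()
        end.record()
        torch.cuda.synchronize()  # elapsed_time is only valid after both events complete
        times.append(start.elapsed_time(end))
    return times
```

Reusing one event pair is safe because `record()` re-arms each event on every iteration, and the synchronize after `end.record()` guarantees each measurement is complete before it is read.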
@@ -203,7 +200,13 @@ def main():
     pipe.to("cuda")

     if args.benchmark:
-        benchmark_model(pipe, args.prompt, model_dtype=args.model_dtype)
+        benchmark_backbone_standalone(
+            pipe,
+            num_warmup=10,
+            num_benchmark=100,
+            model_name=args.model,
+            model_dtype=args.model_dtype,
+        )

     if not args.skip_image:
         generate_image(pipe, args.prompt, image_name)
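For quick experiments outside the CLI path in `main()`, the new function can also be driven directly from Python. A sketch under stated assumptions: the `"flux-dev"` default comes from this diff, while the `FluxPipeline` checkpoint id and loading details are assumptions about how the pipeline is constructed elsewhere in the script:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16  # assumed checkpoint/dtype
)
pipe.to("cuda")

avg_ms = benchmark_backbone_standalone(
    pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev", model_dtype="Half"
)
print(f"backbone average latency: {avg_ms:.2f} ms")
```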