4040 "flux-schnell" : ModelType .FLUX_SCHNELL ,
4141}
4242
43- dtype_map = {
44- "Half" : torch .float16 ,
45- "BFloat16" : torch .bfloat16 ,
46- "Float" : torch .float32 ,
47- }
48-
4943
5044def generate_image (pipe , prompt , image_name ):
5145 seed = 42
@@ -60,7 +54,7 @@ def generate_image(pipe, prompt, image_name):
6054
6155
6256def benchmark_model (
63- pipe , prompt , num_warmup = 10 , num_runs = 50 , num_inference_steps = 20 , model_dtype = "Half"
57+ pipe , prompt , num_warmup = 10 , num_runs = 50 , num_inference_steps = 20 , model_dtype = torch . float16
6458):
6559 """Benchmark the backbone model inference time."""
6660 backbone = pipe .transformer if hasattr (pipe , "transformer" ) else pipe .unet
@@ -83,7 +77,7 @@ def forward_hook(_module, _input, _output):
8377 try :
8478 print (f"Starting warmup: { num_warmup } runs" )
8579 for _ in tqdm (range (num_warmup ), desc = "Warmup" ):
86- with torch .amp .autocast ("cuda" , dtype = dtype_map [ model_dtype ] ):
80+ with torch .amp .autocast ("cuda" , dtype = model_dtype ):
8781 _ = pipe (
8882 prompt ,
8983 output_type = "pil" ,
@@ -95,7 +89,7 @@ def forward_hook(_module, _input, _output):
9589
9690 print (f"Starting benchmark: { num_runs } runs" )
9791 for _ in tqdm (range (num_runs ), desc = "Benchmark" ):
98- with torch .amp .autocast ("cuda" , dtype = dtype_map [ model_dtype ] ):
92+ with torch .amp .autocast ("cuda" , dtype = model_dtype ):
9993 _ = pipe (
10094 prompt ,
10195 output_type = "pil" ,
@@ -169,17 +163,17 @@ def main():
169163 # Save the backbone of the pipeline and move it to the GPU
170164 add_embedding = None
171165 backbone = None
172- model_dtype = None
173166 if hasattr (pipe , "transformer" ):
174167 backbone = pipe .transformer
175- model_dtype = "Bfloat16"
176168 elif hasattr (pipe , "unet" ):
177169 backbone = pipe .unet
178170 add_embedding = backbone .add_embedding
179- model_dtype = "Half"
180171 else :
181172 raise ValueError ("Pipeline does not have a transformer or unet backbone" )
182173
174+ # Get dtype directly from the backbone's parameters
175+ model_dtype = next (backbone .parameters ()).dtype
176+
183177 if args .restore_from :
184178 mto .restore (backbone , args .restore_from )
185179
0 commit comments