File tree Expand file tree Collapse file tree 1 file changed +10
-3
lines changed
examples/diffusers/quantization Expand file tree Collapse file tree 1 file changed +10
-3
lines changed Original file line number Diff line number Diff line change 4040 "flux-schnell" : ModelType .FLUX_SCHNELL ,
4141}
4242
43+ DTYPE_MAP = {
44+ "sdxl-1.0" : torch .float16 ,
45+ "sdxl-turbo" : torch .float16 ,
46+ "sd3-medium" : torch .float16 ,
47+ "flux-dev" : torch .bfloat16 ,
48+ "flux-schnell" : torch .bfloat16 ,
49+ }
50+
4351
4452def generate_image (pipe , prompt , image_name ):
4553 seed = 42
@@ -154,9 +162,11 @@ def main():
154162 args = parser .parse_args ()
155163
156164 image_name = args .save_image_as if args .save_image_as else f"{ args .model } .png"
165+ model_dtype = DTYPE_MAP [args .model ]
157166
158167 pipe = PipelineManager .create_pipeline_from (
159168 MODEL_ID [args .model ],
169+ torch_dtype = model_dtype ,
160170 override_model_path = args .override_model_path ,
161171 )
162172
@@ -171,9 +181,6 @@ def main():
171181 else :
172182 raise ValueError ("Pipeline does not have a transformer or unet backbone" )
173183
174- # Get dtype directly from the backbone's parameters
175- model_dtype = next (backbone .parameters ()).dtype
176-
177184 if args .restore_from :
178185 mto .restore (backbone , args .restore_from )
179186
You can’t perform that action at this time.
0 commit comments