@@ -126,13 +126,6 @@ def main():
126126 default = None ,
127127 help = "Path to the model if not using default paths in MODEL_ID mapping." ,
128128 )
129- parser .add_argument (
130- "--model-dtype" ,
131- type = str ,
132- default = "Half" ,
133- choices = ["Half" , "BFloat16" , "Float" ],
134- help = "Precision used to load the model." ,
135- )
136129 parser .add_argument (
137130 "--restore-from" , type = str , default = None , help = "Path to the modelopt quantized checkpoint"
138131 )
@@ -170,28 +163,27 @@ def main():
170163
171164 pipe = PipelineManager .create_pipeline_from (
172165 MODEL_ID [args .model ],
173- dtype_map [args .model_dtype ],
174166 override_model_path = args .override_model_path ,
175167 )
176168
177169 # Save the backbone of the pipeline and move it to the GPU
178170 add_embedding = None
179171 backbone = None
172+ model_dtype = None
180173 if hasattr (pipe , "transformer" ):
181174 backbone = pipe .transformer
175+ model_dtype = "Bfloat16"
182176 elif hasattr (pipe , "unet" ):
183177 backbone = pipe .unet
184178 add_embedding = backbone .add_embedding
179+ model_dtype = "Half"
185180 else :
186181 raise ValueError ("Pipeline does not have a transformer or unet backbone" )
187182
188183 if args .restore_from :
189184 mto .restore (backbone , args .restore_from )
190185
191186 if args .torch_compile :
192- assert args .model_dtype in ["BFloat16" , "Float" , "Half" ], (
193- "torch.compile() only supports BFloat16 and Float"
194- )
195187 print ("Compiling backbone with torch.compile()..." )
196188 backbone = torch .compile (backbone , mode = "max-autotune" )
197189
@@ -203,7 +195,7 @@ def main():
203195 pipe .to ("cuda" )
204196
205197 if args .benchmark :
206- benchmark_model (pipe , args .prompt , model_dtype = args . model_dtype )
198+ benchmark_model (pipe , args .prompt , model_dtype = model_dtype )
207199
208200 if not args .skip_image :
209201 generate_image (pipe , args .prompt , image_name )
0 commit comments