@@ -240,6 +240,9 @@ def main_export(
240240 loading_kwargs = model_loading_kwargs or {}
241241 if variant is not None :
242242 loading_kwargs ["variant" ] = variant
243+ dtype = loading_kwargs .get ("torch_dtype" , None )
244+ if isinstance (dtype , str ):
245+ dtype = getattr (torch , dtype ) if dtype != "auto" else dtype
243246 if library_name == "transformers" :
244247 config = AutoConfig .from_pretrained (
245248 model_name_or_path ,
@@ -302,9 +305,8 @@ def main_export(
302305 "Please provide custom export config if you want load model with remote code."
303306 )
304307 trust_remote_code = False
305- dtype = loading_kwargs .get ("torch_dtype" )
306- if isinstance (dtype , str ):
307- dtype = getattr (config , "torch_dtype" ) if dtype == "auto" else getattr (torch , dtype )
308+ if dtype == "auto" :
309+ dtype = getattr (config , "torch_dtype" )
308310
309311 if (
310312 dtype is None
@@ -351,19 +353,28 @@ class StoreAttr(object):
351353 GPTQQuantizer .post_init_model = post_init_model
352354 elif library_name == "diffusers" and is_openvino_version (">=" , "2024.6" ):
353355 _loading_kwargs = {} if variant is None else {"variant" : variant }
354- dtype = deduce_diffusers_dtype (
355- model_name_or_path ,
356- revision = revision ,
357- cache_dir = cache_dir ,
358- token = token ,
359- local_files_only = local_files_only ,
360- force_download = force_download ,
361- trust_remote_code = trust_remote_code ,
362- ** _loading_kwargs ,
363- )
356+ if dtype == "auto" or dtype is None :
357+ dtype = deduce_diffusers_dtype (
358+ model_name_or_path ,
359+ revision = revision ,
360+ cache_dir = cache_dir ,
361+ token = token ,
362+ local_files_only = local_files_only ,
363+ force_download = force_download ,
364+ trust_remote_code = trust_remote_code ,
365+ ** _loading_kwargs ,
366+ )
367+ if (
368+ dtype in {torch .bfloat16 , torch .float16 }
369+ and ov_config is not None
370+ and ov_config .dtype in {"fp16" , "fp32" }
371+ ):
372+ dtype = torch .float16 if ov_config .dtype == "fp16" else torch .float32
364373 if dtype in [torch .float16 , torch .bfloat16 ]:
365374 loading_kwargs ["torch_dtype" ] = dtype
366375 patch_16bit = True
376+ if loading_kwargs .get ("torch_dtype" ) == "auto" :
377+ loading_kwargs ["torch_dtype" ] = dtype
367378
368379 try :
369380 if library_name == "open_clip" :
0 commit comments