@@ -1343,10 +1343,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13431343                revision = revision ,
13441344            )
13451345
1346-         allow_pickle  =  False 
1347-         if  use_safetensors  is  None :
1348-             use_safetensors  =  True 
1349-             allow_pickle  =  True 
1346+         allow_pickle  =  True  if  (use_safetensors  is  None  or  use_safetensors  is  False ) else  False 
1347+         use_safetensors  =  use_safetensors  if  use_safetensors  is  not None  else  True 
13501348
13511349        allow_patterns  =  None 
13521350        ignore_patterns  =  None 
@@ -1361,6 +1359,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13611359                model_info_call_error  =  e   # save error to reraise it if model is not cached locally 
13621360
13631361        if  not  local_files_only :
1362+             config_file  =  hf_hub_download (
1363+                 pretrained_model_name ,
1364+                 cls .config_name ,
1365+                 cache_dir = cache_dir ,
1366+                 revision = revision ,
1367+                 proxies = proxies ,
1368+                 force_download = force_download ,
1369+                 token = token ,
1370+             )
1371+             config_dict  =  cls ._dict_from_json_file (config_file )
1372+             ignore_filenames  =  config_dict .pop ("_ignore_files" , [])
1373+ 
13641374            filenames  =  {sibling .rfilename  for  sibling  in  info .siblings }
13651375            if  variant  is  not None  and  _check_legacy_sharding_variant_format (filenames = filenames , variant = variant ):
13661376                warn_msg  =  (
@@ -1375,61 +1385,20 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13751385                )
13761386                logger .warning (warn_msg )
13771387
1378-             model_filenames , variant_filenames  =  variant_compatible_siblings (
1379-                 filenames , variant = variant , use_safetensors = use_safetensors 
1380-             )
1381- 
1382-             config_file  =  hf_hub_download (
1383-                 pretrained_model_name ,
1384-                 cls .config_name ,
1385-                 cache_dir = cache_dir ,
1386-                 revision = revision ,
1387-                 proxies = proxies ,
1388-                 force_download = force_download ,
1389-                 token = token ,
1390-             )
1391- 
1392-             config_dict  =  cls ._dict_from_json_file (config_file )
1393-             ignore_filenames  =  config_dict .pop ("_ignore_files" , [])
1394- 
1395-             # remove ignored filenames 
1396-             model_filenames  =  set (model_filenames ) -  set (ignore_filenames )
1397-             variant_filenames  =  set (variant_filenames ) -  set (ignore_filenames )
1398- 
1388+             filenames  =  set (filenames ) -  set (ignore_filenames )
13991389            if  revision  in  DEPRECATED_REVISION_ARGS  and  version .parse (
14001390                version .parse (__version__ ).base_version 
14011391            ) >=  version .parse ("0.22.0" ):
1402-                 warn_deprecated_model_variant (pretrained_model_name , token , variant , revision , model_filenames )
1392+                 warn_deprecated_model_variant (pretrained_model_name , token , variant , revision , filenames )
14031393
14041394            custom_components , folder_names  =  _get_custom_components_and_folders (
1405-                 pretrained_model_name , config_dict , filenames , variant_filenames ,  variant 
1395+                 pretrained_model_name , config_dict , filenames , variant 
14061396            )
1407-             model_folder_names  =  {os .path .split (f )[0 ] for  f  in  model_filenames  if  os .path .split (f )[0 ] in  folder_names }
1408- 
14091397            custom_class_name  =  None 
14101398            if  custom_pipeline  is  None  and  isinstance (config_dict ["_class_name" ], (list , tuple )):
14111399                custom_pipeline  =  config_dict ["_class_name" ][0 ]
14121400                custom_class_name  =  config_dict ["_class_name" ][1 ]
14131401
1414-             # all filenames compatible with variant will be added 
1415-             allow_patterns  =  list (model_filenames )
1416- 
1417-             # allow all patterns from non-model folders 
1418-             # this enables downloading schedulers, tokenizers, ... 
1419-             allow_patterns  +=  [f"{ k }   for  k  in  folder_names  if  k  not  in model_folder_names ]
1420-             # add custom component files 
1421-             allow_patterns  +=  [f"{ k } { f }   for  k , f  in  custom_components .items ()]
1422-             # add custom pipeline file 
1423-             allow_patterns  +=  [f"{ custom_pipeline }  ] if  f"{ custom_pipeline }   in  filenames  else  []
1424-             # also allow downloading config.json files with the model 
1425-             allow_patterns  +=  [os .path .join (k , "config.json" ) for  k  in  model_folder_names ]
1426-             allow_patterns  +=  [
1427-                 SCHEDULER_CONFIG_NAME ,
1428-                 CONFIG_NAME ,
1429-                 cls .config_name ,
1430-                 CUSTOM_PIPELINE_FILE_NAME ,
1431-             ]
1432- 
14331402            load_pipe_from_hub  =  custom_pipeline  is  not None  and  f"{ custom_pipeline }   in  filenames 
14341403            load_components_from_hub  =  len (custom_components ) >  0 
14351404
@@ -1446,6 +1415,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14461415                    f"load the model. You can inspect the repository content at { ', ' .join ([f'https://hf.co/{ pretrained_model_name } { k } { v }   for  k ,v  in  custom_components .items ()])} \n " 
14471416                    f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." 
14481417                )
1418+             model_folder_names  =  {os .path .split (f )[0 ] for  f  in  filenames  if  os .path .split (f )[0 ] in  folder_names }
14491419
14501420            # retrieve passed components that should not be downloaded 
14511421            pipeline_class  =  _get_pipeline_class (
@@ -1466,8 +1436,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14661436            ignore_patterns  =  _get_ignore_patterns (
14671437                passed_components ,
14681438                model_folder_names ,
1469-                 model_filenames ,
1470-                 variant_filenames ,
1439+                 filenames ,
14711440                use_safetensors ,
14721441                from_flax ,
14731442                allow_pickle ,
@@ -1476,6 +1445,49 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14761445                variant ,
14771446            )
14781447
1448+             model_filenames , variant_filenames  =  variant_compatible_siblings (
1449+                 filenames , variant = variant , ignore_patterns = ignore_patterns 
1450+             )
1451+ 
1452+             safetensors_variant_filenames  =  {f  for  f  in  variant_filenames  if  f .endswith (".safetensors" )}
1453+             safetensors_model_filenames  =  {f  for  f  in  model_filenames  if  f .endswith (".safetensors" )}
1454+             if  len (safetensors_variant_filenames ) >  0  and  safetensors_model_filenames  !=  safetensors_variant_filenames :
1455+                 logger .warning (
1456+                     f"\n A mixture of { variant } { variant } \n Loaded { variant } \n " 
1457+                     f"[{ ', ' .join (safetensors_variant_filenames )} \n Loaded non-{ variant } \n " 
1458+                     f"[{ ', ' .join (safetensors_model_filenames  -  safetensors_variant_filenames )} \n If this behavior is not " 
1459+                     f"expected, please check your folder structure." 
1460+                 )
1461+ 
1462+             bin_variant_filenames  =  {f  for  f  in  variant_filenames  if  f .endswith (".bin" )}
1463+             bin_model_filenames  =  {f  for  f  in  model_filenames  if  f .endswith (".bin" )}
1464+             if  len (bin_variant_filenames ) >  0  and  bin_model_filenames  !=  bin_variant_filenames :
1465+                 logger .warning (
1466+                     f"\n A mixture of { variant } { variant } \n Loaded { variant } \n " 
1467+                     f"[{ ', ' .join (bin_variant_filenames )} \n Loaded non-{ variant } \n " 
1468+                     f"[{ ', ' .join (bin_model_filenames  -  bin_variant_filenames )} \n If this behavior is not expected, please check " 
1469+                     f"your folder structure." 
1470+                 )
1471+ 
1472+             # all filenames compatible with variant will be added 
1473+             allow_patterns  =  list (model_filenames )
1474+ 
1475+             # allow all patterns from non-model folders 
1476+             # this enables downloading schedulers, tokenizers, ... 
1477+             allow_patterns  +=  [f"{ k }   for  k  in  folder_names  if  k  not  in model_folder_names ]
1478+             # add custom component files 
1479+             allow_patterns  +=  [f"{ k } { f }   for  k , f  in  custom_components .items ()]
1480+             # add custom pipeline file 
1481+             allow_patterns  +=  [f"{ custom_pipeline }  ] if  f"{ custom_pipeline }   in  filenames  else  []
1482+             # also allow downloading config.json files with the model 
1483+             allow_patterns  +=  [os .path .join (k , "config.json" ) for  k  in  model_folder_names ]
1484+             allow_patterns  +=  [
1485+                 SCHEDULER_CONFIG_NAME ,
1486+                 CONFIG_NAME ,
1487+                 cls .config_name ,
1488+                 CUSTOM_PIPELINE_FILE_NAME ,
1489+             ]
1490+ 
14791491            # Don't download any objects that are passed 
14801492            allow_patterns  =  [
14811493                p  for  p  in  allow_patterns  if  not  (len (p .split ("/" )) ==  2  and  p .split ("/" )[0 ] in  passed_components )
0 commit comments