7474 CONNECTED_PIPES_KEYS ,
7575 CUSTOM_PIPELINE_FILE_NAME ,
7676 LOADABLE_CLASSES ,
77+ _download_dduf_file ,
7778 _fetch_class_library_tuple ,
7879 _get_custom_components_and_folders ,
7980 _get_custom_pipeline_class ,
8081 _get_final_device_map ,
8182 _get_ignore_patterns ,
8283 _get_pipeline_class ,
8384 _identify_model_variants ,
85+ _maybe_raise_error_for_incorrect_transformers ,
8486 _maybe_raise_warning_for_inpainting ,
8587 _resolve_custom_pipeline_and_cls ,
8688 _unwrap_model ,
@@ -728,8 +730,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
728730 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
729731 )
730732
731- if dduf_file and custom_pipeline :
732- raise NotImplementedError ("Custom pipelines are not supported with DDUF at the moment." )
733+ if dduf_file :
734+ if custom_pipeline :
735+ raise NotImplementedError ("Custom pipelines are not supported with DDUF at the moment." )
736+ if load_connected_pipeline :
737+ raise NotImplementedError ("Connected pipelines are not supported with DDUF at the moment." )
733738
734739 # 1. Download the checkpoints and configs
735740 # use snapshot download here to get it working from from_pretrained
@@ -785,14 +790,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
785790 config_dict = cls .load_config (cached_folder , dduf_entries = dduf_entries )
786791
787792 if dduf_file :
788- has_transformers_component = False
789- for k in config_dict :
790- if isinstance (config_dict [k ], list ):
791- has_transformers_component = config_dict [k ][0 ] == "transformers"
792- if has_transformers_component :
793- break
794- if has_transformers_component and not is_transformers_version (">" , "4.47.1" ):
795- raise ValueError ("Please upgrade your `transformers` installation to the latest version to use DDUF." )
793+ _maybe_raise_error_for_incorrect_transformers (config_dict )
796794
797795 # pop out "_ignore_files" as it is only needed for download
798796 config_dict .pop ("_ignore_files" , None )
@@ -1328,8 +1326,21 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13281326 trust_remote_code = kwargs .pop ("trust_remote_code" , False )
13291327 dduf_file : Optional [Dict [str , DDUFEntry ]] = kwargs .pop ("dduf_file" , None )
13301328
1331- if dduf_file and custom_pipeline :
1332- raise NotImplementedError ("Custom pipelines are not supported with DDUF at the moment." )
1329+ if dduf_file :
1330+ if custom_pipeline :
1331+ raise NotImplementedError ("Custom pipelines are not supported with DDUF at the moment." )
1332+ if load_connected_pipeline :
1333+ raise NotImplementedError ("Connected pipelines are not supported with DDUF at the moment." )
1334+ return _download_dduf_file (
1335+ pretrained_model_name = pretrained_model_name ,
1336+ dduf_file = dduf_file ,
1337+ pipeline_class_name = cls .__name__ ,
1338+ cache_dir = cache_dir ,
1339+ proxies = proxies ,
1340+ local_files_only = local_files_only ,
1341+ token = token ,
1342+ revision = revision ,
1343+ )
13331344
13341345 allow_pickle = False
13351346 if use_safetensors is None :
@@ -1348,14 +1359,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13481359 local_files_only = True
13491360 model_info_call_error = e # save error to reraise it if model is not cached locally
13501361
1351- if (
1352- not local_files_only
1353- and dduf_file is not None
1354- and dduf_file not in (sibling .rfilename for sibling in info .siblings )
1355- ):
1356- raise ValueError (f"Requested { dduf_file } file is not available in { pretrained_model_name } ." )
1357-
1358- if not local_files_only and not dduf_file :
1362+ if not local_files_only :
13591363 filenames = {sibling .rfilename for sibling in info .siblings }
13601364 if variant is not None and _check_legacy_sharding_variant_format (filenames = filenames , variant = variant ):
13611365 warn_msg = (
@@ -1498,10 +1502,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14981502 user_agent ["custom_pipeline" ] = custom_pipeline
14991503
15001504 # download all allow_patterns - ignore_patterns
1501- # also allow downloading the dduf_file
1502- if dduf_file is not None :
1503- allow_patterns = [dduf_file ]
1504- ignore_patterns = []
15051505 try :
15061506 cached_folder = snapshot_download (
15071507 pretrained_model_name ,
@@ -1515,10 +1515,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
15151515 user_agent = user_agent ,
15161516 )
15171517
1518- # retrieve pipeline class from local file
1519- if dduf_file :
1520- return cached_folder
1521-
15221518 cls_name = cls .load_config (os .path .join (cached_folder , "model_index.json" )).get ("_class_name" , None )
15231519 cls_name = cls_name [4 :] if isinstance (cls_name , str ) and cls_name .startswith ("Flax" ) else cls_name
15241520
0 commit comments