@@ -1291,7 +1291,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         trust_remote_code = kwargs.pop("trust_remote_code", False)
-        dduf_file = kwargs.pop("dduf_file", None)
+        dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -1310,11 +1310,11 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
                 local_files_only = True
                 model_info_call_error = e  # save error to reraise it if model is not cached locally

-        if dduf_file is not None and not local_files_only:
-            dduf_available = False
-            for sibling in info.siblings:
-                dduf_available = dduf_file in sibling.rfilename
-        if not dduf_available:
+        if (
+            not local_files_only
+            and dduf_file is not None
+            and dduf_file not in (sibling.rfilename for sibling in info.siblings)
+        ):
             raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")

         if not local_files_only and not dduf_file:
@@ -1478,27 +1478,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         )

         # retrieve pipeline class from local file
-        if not dduf_file:
-            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
-            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
-
-            diffusers_module = importlib.import_module(__name__.split(".")[0])
-            pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
-
-            if pipeline_class is not None and pipeline_class._load_connected_pipes:
-                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-                connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
-                for connected_pipe_repo_id in connected_pipes:
-                    download_kwargs = {
-                        "cache_dir": cache_dir,
-                        "force_download": force_download,
-                        "proxies": proxies,
-                        "local_files_only": local_files_only,
-                        "token": token,
-                        "variant": variant,
-                        "use_safetensors": use_safetensors,
-                    }
-                    DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
+        if dduf_file:
+            return cached_folder
+
+        cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
+        cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
+
+        diffusers_module = importlib.import_module(__name__.split(".")[0])
+        pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
+
+        if pipeline_class is not None and pipeline_class._load_connected_pipes:
+            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
+            connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
+            for connected_pipe_repo_id in connected_pipes:
+                download_kwargs = {
+                    "cache_dir": cache_dir,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "local_files_only": local_files_only,
+                    "token": token,
+                    "variant": variant,
+                    "use_safetensors": use_safetensors,
+                }
+                DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)

         return cached_folder

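For context, here is a minimal standalone sketch (not part of the diff above) of why the availability check was rewritten. The Sibling dataclass and the file names are made up for illustration; Sibling only mimics the rfilename attribute of the sibling entries that huggingface_hub's model_info() returns.

from dataclasses import dataclass

@dataclass
class Sibling:
    rfilename: str

siblings = [Sibling("pipeline.dduf"), Sibling("model_index.json")]
dduf_file = "pipeline.dduf"  # hypothetical requested DDUF archive name

# Old logic: the flag is overwritten on every iteration, so only the last
# sibling counts, and "in" on a string is a substring test, not an equality test.
dduf_available = False
for sibling in siblings:
    dduf_available = dduf_file in sibling.rfilename
print(dduf_available)  # False, even though pipeline.dduf is the first sibling

# New logic: a single membership test over all repository filenames.
print(dduf_file in (sibling.rfilename for sibling in siblings))  # True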