@@ -1290,7 +1290,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12901290 use_onnx = kwargs .pop ("use_onnx" , None )
12911291 load_connected_pipeline = kwargs .pop ("load_connected_pipeline" , False )
12921292 trust_remote_code = kwargs .pop ("trust_remote_code" , False )
1293- dduf_file = kwargs .pop ("dduf_file" , None )
1293+ dduf_file : Optional [ Dict [ str , DDUFEntry ]] = kwargs .pop ("dduf_file" , None )
12941294
12951295 allow_pickle = False
12961296 if use_safetensors is None :
@@ -1309,11 +1309,11 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13091309 local_files_only = True
13101310 model_info_call_error = e # save error to reraise it if model is not cached locally
13111311
1312- if dduf_file is not None and not local_files_only :
1313- dduf_available = False
1314- for sibling in info . siblings :
1315- dduf_available = dduf_file in sibling . rfilename
1316- if not dduf_available :
1312+ if (
1313+ not local_files_only
1314+ and dduf_file is not None
1315+ and dduf_file not in ( sibling . rfilename for sibling in info . siblings )
1316+ ):
13171317 raise ValueError (f"Requested { dduf_file } file is not available in { pretrained_model_name } ." )
13181318
13191319 if not local_files_only and not dduf_file :
@@ -1477,27 +1477,29 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
14771477 )
14781478
14791479 # retrieve pipeline class from local file
1480- if not dduf_file :
1481- cls_name = cls .load_config (os .path .join (cached_folder , "model_index.json" )).get ("_class_name" , None )
1482- cls_name = cls_name [4 :] if isinstance (cls_name , str ) and cls_name .startswith ("Flax" ) else cls_name
1483-
1484- diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
1485- pipeline_class = getattr (diffusers_module , cls_name , None ) if isinstance (cls_name , str ) else None
1486-
1487- if pipeline_class is not None and pipeline_class ._load_connected_pipes :
1488- modelcard = ModelCard .load (os .path .join (cached_folder , "README.md" ))
1489- connected_pipes = sum ([getattr (modelcard .data , k , []) for k in CONNECTED_PIPES_KEYS ], [])
1490- for connected_pipe_repo_id in connected_pipes :
1491- download_kwargs = {
1492- "cache_dir" : cache_dir ,
1493- "force_download" : force_download ,
1494- "proxies" : proxies ,
1495- "local_files_only" : local_files_only ,
1496- "token" : token ,
1497- "variant" : variant ,
1498- "use_safetensors" : use_safetensors ,
1499- }
1500- DiffusionPipeline .download (connected_pipe_repo_id , ** download_kwargs )
1480+ if dduf_file :
1481+ return cached_folder
1482+
1483+ cls_name = cls .load_config (os .path .join (cached_folder , "model_index.json" )).get ("_class_name" , None )
1484+ cls_name = cls_name [4 :] if isinstance (cls_name , str ) and cls_name .startswith ("Flax" ) else cls_name
1485+
1486+ diffusers_module = importlib .import_module (__name__ .split ("." )[0 ])
1487+ pipeline_class = getattr (diffusers_module , cls_name , None ) if isinstance (cls_name , str ) else None
1488+
1489+ if pipeline_class is not None and pipeline_class ._load_connected_pipes :
1490+ modelcard = ModelCard .load (os .path .join (cached_folder , "README.md" ))
1491+ connected_pipes = sum ([getattr (modelcard .data , k , []) for k in CONNECTED_PIPES_KEYS ], [])
1492+ for connected_pipe_repo_id in connected_pipes :
1493+ download_kwargs = {
1494+ "cache_dir" : cache_dir ,
1495+ "force_download" : force_download ,
1496+ "proxies" : proxies ,
1497+ "local_files_only" : local_files_only ,
1498+ "token" : token ,
1499+ "variant" : variant ,
1500+ "use_safetensors" : use_safetensors ,
1501+ }
1502+ DiffusionPipeline .download (connected_pipe_repo_id , ** download_kwargs )
15011503
15021504 return cached_folder
15031505
0 commit comments