|
71 | 71 | CUSTOM_PIPELINE_FILE_NAME, |
72 | 72 | LOADABLE_CLASSES, |
73 | 73 | _fetch_class_library_tuple, |
| 74 | + _get_custom_components_and_folders, |
74 | 75 | _get_custom_pipeline_class, |
75 | 76 | _get_final_device_map, |
| 77 | + _get_ignore_patterns, |
76 | 78 | _get_pipeline_class, |
77 | 79 | _identify_model_variants, |
78 | 80 | _maybe_raise_warning_for_inpainting, |
79 | 81 | _resolve_custom_pipeline_and_cls, |
80 | 82 | _unwrap_model, |
81 | 83 | _update_init_kwargs_with_connected_pipeline, |
82 | | - is_safetensors_compatible, |
83 | 84 | load_sub_model, |
84 | 85 | maybe_raise_or_warn, |
85 | 86 | variant_compatible_siblings, |
@@ -1298,44 +1299,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: |
1298 | 1299 | config_dict = cls._dict_from_json_file(config_file) |
1299 | 1300 | ignore_filenames = config_dict.pop("_ignore_files", []) |
1300 | 1301 |
|
1301 | | - # retrieve all folder_names that contain relevant files |
1302 | | - folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] |
1303 | | - |
1304 | | - diffusers_module = importlib.import_module(__name__.split(".")[0]) |
1305 | | - pipelines = getattr(diffusers_module, "pipelines") |
1306 | | - |
1307 | | - # optionally create a custom component <> custom file mapping |
1308 | | - custom_components = {} |
1309 | | - for component in folder_names: |
1310 | | - module_candidate = config_dict[component][0] |
1311 | | - |
1312 | | - if module_candidate is None or not isinstance(module_candidate, str): |
1313 | | - continue |
1314 | | - |
1315 | | - # We compute candidate file path on the Hub. Do not use `os.path.join`. |
1316 | | - candidate_file = f"{component}/{module_candidate}.py" |
1317 | | - |
1318 | | - if candidate_file in filenames: |
1319 | | - custom_components[component] = module_candidate |
1320 | | - elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): |
1321 | | - raise ValueError( |
1322 | | - f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." |
1323 | | - ) |
1324 | | - |
1325 | | - if len(variant_filenames) == 0 and variant is not None: |
1326 | | - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." |
1327 | | - raise ValueError(error_message) |
1328 | | - |
1329 | 1302 | # remove ignored filenames |
1330 | 1303 | model_filenames = set(model_filenames) - set(ignore_filenames) |
1331 | 1304 | variant_filenames = set(variant_filenames) - set(ignore_filenames) |
1332 | 1305 |
|
1333 | | - # if the whole pipeline is cached we don't have to ping the Hub |
1334 | 1306 | if revision in DEPRECATED_REVISION_ARGS and version.parse( |
1335 | 1307 | version.parse(__version__).base_version |
1336 | 1308 | ) >= version.parse("0.22.0"): |
1337 | 1309 | warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) |
1338 | 1310 |
|
| 1311 | + custom_components, folder_names = _get_custom_components_and_folders( |
| 1312 | + pretrained_model_name, config_dict, filenames, variant_filenames, variant |
| 1313 | + ) |
1339 | 1314 | model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} |
1340 | 1315 |
|
1341 | 1316 | custom_class_name = None |
@@ -1395,49 +1370,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: |
1395 | 1370 | expected_components, _ = cls._get_signature_keys(pipeline_class) |
1396 | 1371 | passed_components = [k for k in expected_components if k in kwargs] |
1397 | 1372 |
|
1398 | | - if ( |
1399 | | - use_safetensors |
1400 | | - and not allow_pickle |
1401 | | - and not is_safetensors_compatible( |
1402 | | - model_filenames, passed_components=passed_components, folder_names=model_folder_names |
1403 | | - ) |
1404 | | - ): |
1405 | | - raise EnvironmentError( |
1406 | | - f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" |
1407 | | - ) |
1408 | | - if from_flax: |
1409 | | - ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] |
1410 | | - elif use_safetensors and is_safetensors_compatible( |
1411 | | - model_filenames, passed_components=passed_components, folder_names=model_folder_names |
1412 | | - ): |
1413 | | - ignore_patterns = ["*.bin", "*.msgpack"] |
1414 | | - |
1415 | | - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx |
1416 | | - if not use_onnx: |
1417 | | - ignore_patterns += ["*.onnx", "*.pb"] |
1418 | | - |
1419 | | - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} |
1420 | | - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} |
1421 | | - if ( |
1422 | | - len(safetensors_variant_filenames) > 0 |
1423 | | - and safetensors_model_filenames != safetensors_variant_filenames |
1424 | | - ): |
1425 | | - logger.warning( |
1426 | | - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." |
1427 | | - ) |
1428 | | - else: |
1429 | | - ignore_patterns = ["*.safetensors", "*.msgpack"] |
1430 | | - |
1431 | | - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx |
1432 | | - if not use_onnx: |
1433 | | - ignore_patterns += ["*.onnx", "*.pb"] |
1434 | | - |
1435 | | - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} |
1436 | | - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} |
1437 | | - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: |
1438 | | - logger.warning( |
1439 | | - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." |
1440 | | - ) |
| 1373 | + # retrieve all patterns that should not be downloaded and error out when needed |
| 1374 | + ignore_patterns = _get_ignore_patterns( |
| 1375 | + passed_components, |
| 1376 | + model_folder_names, |
| 1377 | + model_filenames, |
| 1378 | + variant_filenames, |
| 1379 | + use_safetensors, |
| 1380 | + from_flax, |
| 1381 | + allow_pickle, |
| 1382 | + use_onnx, |
| 1383 | + pipeline_class._is_onnx, |
| 1384 | + variant, |
| 1385 | + ) |
1441 | 1386 |
|
1442 | 1387 | # Don't download any objects that are passed |
1443 | 1388 | allow_patterns = [ |
|
0 commit comments