-
Couldn't load subscription status.
- Fork 6.4k
[refactor] DiffusionPipeline.download #9557
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 5 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
9d57f79
update
a-r-r-o-w 5db3506
make style
a-r-r-o-w 87805a5
fix import
a-r-r-o-w 20c08f5
update test
a-r-r-o-w e821552
Merge branch 'main' into refactor-download
a-r-r-o-w a152332
Merge branch 'main' into refactor-download
sayakpaul 41ccf7c
Merge branch 'main' into refactor-download
DN6 09fa592
Update src/diffusers/pipelines/pipeline_utils.py
a-r-r-o-w 8d1314c
apply suggestions from review
a-r-r-o-w e19e800
fix import
a-r-r-o-w File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,7 @@ | |
| is_transformers_available, | ||
| logging, | ||
| ) | ||
| from ..utils.hub_utils import _check_legacy_sharding_variant_format | ||
| from ..utils.torch_utils import is_compiled_module | ||
|
|
||
|
|
||
|
|
@@ -838,3 +839,127 @@ def get_connected_passed_kwargs(prefix): | |
| ) | ||
|
|
||
| return init_kwargs | ||
|
|
||
|
|
||
| def _maybe_raise_warning_for_variant_checkpoint_format( | ||
| folder: Optional[str] = None, filenames: Optional[List[str]] = None, variant: Optional[str] = None | ||
| ) -> None: | ||
| if variant is not None and _check_legacy_sharding_variant_format( | ||
| folder=folder, filenames=filenames, variant=variant | ||
| ): | ||
| warn_msg = ( | ||
| f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. " | ||
| "Please check your files carefully:\n\n" | ||
| "- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n" | ||
| "- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n" | ||
| "If you find any files in the deprecated format:\n" | ||
| "1. Remove all existing checkpoint files for this variant.\n" | ||
| "2. Re-obtain the correct files by running `save_pretrained()`.\n\n" | ||
| "This will ensure you're using the most up-to-date and compatible checkpoint format." | ||
| ) | ||
| logger.warning(warn_msg) | ||
|
|
||
|
|
||
| def _get_custom_components_and_folders( | ||
| pretrained_model_name: str, | ||
| config_dict: Dict[str, Any], | ||
| filenames: Optional[List[str]] = None, | ||
| variant_filenames: Optional[List[str]] = None, | ||
| variant: Optional[str] = None, | ||
| ): | ||
| config_dict = config_dict.copy() | ||
|
|
||
| # retrieve all folder_names that contain relevant files | ||
| folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] | ||
|
|
||
| diffusers_module = importlib.import_module(__name__.split(".")[0]) | ||
| pipelines = getattr(diffusers_module, "pipelines") | ||
|
|
||
| # optionally create a custom component <> custom file mapping | ||
| custom_components = {} | ||
| for component in folder_names: | ||
| module_candidate = config_dict[component][0] | ||
|
|
||
| if module_candidate is None or not isinstance(module_candidate, str): | ||
| continue | ||
|
|
||
| # We compute candidate file path on the Hub. Do not use `os.path.join`. | ||
| candidate_file = f"{component}/{module_candidate}.py" | ||
|
|
||
| if candidate_file in filenames: | ||
| custom_components[component] = module_candidate | ||
| elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): | ||
| raise ValueError( | ||
| f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." | ||
| ) | ||
|
|
||
| if len(variant_filenames) == 0 and variant is not None: | ||
| error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." | ||
| raise ValueError(error_message) | ||
|
|
||
| return custom_components, folder_names | ||
|
|
||
|
|
||
| def _get_ignore_patterns( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice! |
||
| passed_components, | ||
| model_folder_names: List[str], | ||
| model_filenames: List[str], | ||
| variant_filenames: List[str], | ||
| use_safetensors: bool, | ||
| from_flax: bool, | ||
| allow_pickle: bool, | ||
| use_onnx: bool, | ||
| is_onnx: bool, | ||
| variant: Optional[str] = None, | ||
| ) -> List[str]: | ||
| if ( | ||
| use_safetensors | ||
| and not allow_pickle | ||
| and not is_safetensors_compatible( | ||
| model_filenames, passed_components=passed_components, folder_names=model_folder_names | ||
| ) | ||
| ): | ||
| raise EnvironmentError( | ||
| f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" | ||
| ) | ||
|
|
||
| if from_flax: | ||
| ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] | ||
|
|
||
| elif use_safetensors and is_safetensors_compatible( | ||
| model_filenames, passed_components=passed_components, folder_names=model_folder_names | ||
| ): | ||
| ignore_patterns = ["*.bin", "*.msgpack"] | ||
|
|
||
| use_onnx = use_onnx if use_onnx is not None else is_onnx | ||
| if not use_onnx: | ||
| ignore_patterns += ["*.onnx", "*.pb"] | ||
|
|
||
| safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} | ||
| safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} | ||
| if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: | ||
| logger.warning( | ||
| f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" | ||
| f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" | ||
| f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " | ||
| f"expected, please check your folder structure." | ||
| ) | ||
|
|
||
| else: | ||
| ignore_patterns = ["*.safetensors", "*.msgpack"] | ||
|
|
||
| use_onnx = use_onnx if use_onnx is not None else is_onnx | ||
| if not use_onnx: | ||
| ignore_patterns += ["*.onnx", "*.pb"] | ||
|
|
||
| bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} | ||
| bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} | ||
| if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: | ||
| logger.warning( | ||
| f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" | ||
| f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" | ||
| f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " | ||
| f"your folder structure." | ||
| ) | ||
|
|
||
| return ignore_patterns | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.