diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index fdfbb923bae8..3225cd0657bb 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -314,6 +314,8 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                 model is on a network mount or hard drive.
+            checkpoint (`dict`, *optional*):
+                The loaded state dictionary of the model. If provided, loading the checkpoint from `pretrained_model_link_or_path` is skipped.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
                 class). The overwritten components are passed directly to the pipelines `__init__` method. See example
@@ -362,6 +364,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         revision = kwargs.pop("revision", None)
         torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        checkpoint = kwargs.pop("checkpoint", None)
 
         is_legacy_loading = False
 
@@ -386,18 +389,19 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
 
         from ..pipelines.pipeline_utils import _get_pipeline_class
 
-        pipeline_class = _get_pipeline_class(cls, config=None)
-
-        checkpoint = load_single_file_checkpoint(
-            pretrained_model_link_or_path,
-            force_download=force_download,
-            proxies=proxies,
-            token=token,
-            cache_dir=cache_dir,
-            local_files_only=local_files_only,
-            revision=revision,
-            disable_mmap=disable_mmap,
-        )
+        pipeline_class = _get_pipeline_class(cls, class_name=cls.__name__, config=None)
+
+        if checkpoint is None:
+            checkpoint = load_single_file_checkpoint(
+                pretrained_model_link_or_path,
+                force_download=force_download,
+                proxies=proxies,
+                token=token,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                revision=revision,
+                disable_mmap=disable_mmap,
+            )
 
         if config is None:
             config = fetch_diffusers_config(checkpoint)
@@ -480,6 +484,11 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
         init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
 
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+            )
+
         from diffusers import pipelines
 
         # remove `null` components
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 59060efade8b..d1689ced4470 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -2857,3 +2857,83 @@ def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
         converted_state_dict[diffusers_key] = checkpoint.pop(key)
 
     return converted_state_dict
+
+
+def get_keyword_types(keyword):
+    r"""
+    Determine the type and loading method for a given keyword.
+
+    Parameters:
+        keyword (`str`):
+            The input keyword to classify.
+
+    Returns:
+        `dict`: A dictionary containing the detected checkpoint format, the loading method, and boolean flags for
+        the recognized keyword types and extra types.
+ """ + + # Initialize the status dictionary with default values + status = { + "checkpoint_format": None, + "loading_method": None, + "type": { + "other": False, + "hf_url": False, + "hf_repo": False, + "civitai_url": False, + "local": False, + }, + "extra_type": { + "url": False, + "missing_model_index": None, + }, + } + + # Check if the keyword is an HTTP or HTTPS URL + status["extra_type"]["url"] = bool(re.search(r"^(https?)://", keyword)) + + # Check if the keyword is a file + if os.path.isfile(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + + # Check if the keyword is a directory + elif os.path.isdir(keyword): + status["type"]["local"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + if not os.path.exists(os.path.join(keyword, "model_index.json")): + status["extra_type"]["missing_model_index"] = True + + # Check if the keyword is a Civitai URL + elif keyword.startswith("https://civitai.com/"): + status["type"]["civitai_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = None + + # Check if the keyword starts with any valid URL prefixes + elif any(keyword.startswith(prefix) for prefix in VALID_URL_PREFIXES): + repo_id, weights_name = _extract_repo_id_and_weights_name(keyword) + if weights_name: + status["type"]["hf_url"] = True + status["checkpoint_format"] = "single_file" + status["loading_method"] = "from_single_file" + else: + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # Check if the keyword matches a Hugging Face repository format + elif re.match(r"^[^/]+/[^/]+$", keyword): + status["type"]["hf_repo"] = True + status["checkpoint_format"] = "diffusers" + status["loading_method"] = "from_pretrained" + + # If none of the above apply + else: + status["type"]["other"] = True + status["checkpoint_format"] = None + status["loading_method"] = None + + return status diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 4f760ee09add..3b2c80c28e87 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -18,8 +18,10 @@ from huggingface_hub.utils import validate_hf_hub_args from ..configuration_utils import ConfigMixin +from ..loaders.single_file_utils import get_keyword_types, infer_diffusers_model_type, load_single_file_checkpoint from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available +from .animatediff import AnimateDiffPipeline, AnimateDiffSDXLPipeline from .aura_flow import AuraFlowPipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4Pipeline @@ -34,10 +36,7 @@ StableDiffusionXLControlNetUnionInpaintPipeline, StableDiffusionXLControlNetUnionPipeline, ) -from .controlnet_sd3 import ( - StableDiffusion3ControlNetInpaintingPipeline, - StableDiffusion3ControlNetPipeline, -) +from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline from .flux import ( FluxControlImg2ImgPipeline, @@ -94,7 +93,9 @@ from .stable_diffusion import ( StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, + StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline, + StableDiffusionUpscalePipeline, ) from .stable_diffusion_3 import ( 
StableDiffusion3Img2ImgPipeline, @@ -213,6 +214,147 @@ ] ) +SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING = OrderedDict( + [ + ("animatediff_rgb", AnimateDiffPipeline), + ("animatediff_scribble", AnimateDiffPipeline), + ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline), + ("animatediff_v1", AnimateDiffPipeline), + ("animatediff_v2", AnimateDiffPipeline), + ("animatediff_v3", AnimateDiffPipeline), + ("auraflow", AuraFlowPipeline), + ("autoencoder-dc-f128c512", None), + ("autoencoder-dc-f32c32", None), + ("autoencoder-dc-f32c32-sana", None), + ("autoencoder-dc-f64c128", None), + ("controlnet", StableDiffusionControlNetPipeline), + ("controlnet_xl", StableDiffusionXLControlNetPipeline), + ("controlnet_xl_large", StableDiffusionXLControlNetPipeline), + ("controlnet_xl_mid", StableDiffusionXLControlNetPipeline), + ("controlnet_xl_small", StableDiffusionXLControlNetPipeline), + ("flux-depth", FluxPipeline), + ("flux-dev", FluxPipeline), + ("flux-fill", FluxPipeline), + ("flux-schnell", FluxPipeline), + ("hunyuan-video", None), + ("inpainting", None), + ("inpainting_v2", None), + ("instruct-pix2pix", None), + ("lumina2", Lumina2Text2ImgPipeline), + ("ltx-video", None), + ("ltx-video-0.9.1", None), + ("mochi-1-preview", None), + ("playground-v2-5", StableDiffusionXLPipeline), + ("sd3", StableDiffusion3Pipeline), + ("sd35_large", StableDiffusion3Pipeline), + ("sd35_medium", StableDiffusion3Pipeline), + ("stable_cascade_stage_b", None), + ("stable_cascade_stage_b_lite", None), + ("stable_cascade_stage_c", None), + ("stable_cascade_stage_c_lite", None), + ("upscale", StableDiffusionUpscalePipeline), + ("v1", StableDiffusionPipeline), + ("v2", StableDiffusionPipeline), + ("xl_base", StableDiffusionXLPipeline), + ("xl_inpaint", None), + ("xl_refiner", StableDiffusionXLPipeline), + ] +) + +SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING = OrderedDict( + [ + ("animatediff_rgb", AnimateDiffPipeline), + ("animatediff_scribble", AnimateDiffPipeline), + ("animatediff_sdxl_beta", AnimateDiffSDXLPipeline), + ("animatediff_v1", AnimateDiffPipeline), + ("animatediff_v2", AnimateDiffPipeline), + ("animatediff_v3", AnimateDiffPipeline), + ("auraflow", None), + ("autoencoder-dc-f128c512", None), + ("autoencoder-dc-f32c32", None), + ("autoencoder-dc-f32c32-sana", None), + ("autoencoder-dc-f64c128", None), + ("controlnet", StableDiffusionControlNetImg2ImgPipeline), + ("controlnet_xl", StableDiffusionXLControlNetImg2ImgPipeline), + ("controlnet_xl_large", StableDiffusionXLControlNetImg2ImgPipeline), + ("controlnet_xl_mid", StableDiffusionXLControlNetImg2ImgPipeline), + ("controlnet_xl_small", StableDiffusionXLControlNetImg2ImgPipeline), + ("flux-depth", FluxImg2ImgPipeline), + ("flux-dev", FluxImg2ImgPipeline), + ("flux-fill", FluxImg2ImgPipeline), + ("flux-schnell", FluxImg2ImgPipeline), + ("hunyuan-video", None), + ("inpainting", None), + ("inpainting_v2", None), + ("instruct-pix2pix", StableDiffusionInstructPix2PixPipeline), + ("lumina2", None), + ("ltx-video", None), + ("ltx-video-0.9.1", None), + ("mochi-1-preview", None), + ("playground-v2-5", StableDiffusionXLImg2ImgPipeline), + ("sd3", StableDiffusion3Img2ImgPipeline), + ("sd35_large", StableDiffusion3Img2ImgPipeline), + ("sd35_medium", StableDiffusion3Img2ImgPipeline), + ("stable_cascade_stage_b", None), + ("stable_cascade_stage_b_lite", None), + ("stable_cascade_stage_c", None), + ("stable_cascade_stage_c_lite", None), + ("upscale", StableDiffusionUpscalePipeline), + ("v1", StableDiffusionImg2ImgPipeline), + ("v2", StableDiffusionImg2ImgPipeline), 
+ ("xl_base", StableDiffusionXLImg2ImgPipeline), + ("xl_inpaint", None), + ("xl_refiner", StableDiffusionXLImg2ImgPipeline), + ] +) + +SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING = OrderedDict( + [ + ("animatediff_rgb", None), + ("animatediff_scribble", None), + ("animatediff_sdxl_beta", None), + ("animatediff_v1", None), + ("animatediff_v2", None), + ("animatediff_v3", None), + ("auraflow", None), + ("autoencoder-dc-f128c512", None), + ("autoencoder-dc-f32c32", None), + ("autoencoder-dc-f32c32-sana", None), + ("autoencoder-dc-f64c128", None), + ("controlnet", StableDiffusionControlNetInpaintPipeline), + ("controlnet_xl", None), + ("controlnet_xl_large", None), + ("controlnet_xl_mid", None), + ("controlnet_xl_small", None), + ("flux-depth", None), + ("flux-dev", None), + ("flux-fill", None), + ("flux-schnell", None), + ("hunyuan-video", None), + ("inpainting", StableDiffusionInpaintPipeline), + ("inpainting_v2", StableDiffusionInpaintPipeline), + ("instruct-pix2pix", None), + ("lumina2", None), + ("ltx-video", None), + ("ltx-video-0.9.1", None), + ("mochi-1-preview", None), + ("playground-v2-5", None), + ("sd3", None), + ("sd35_large", None), + ("sd35_medium", None), + ("stable_cascade_stage_b", None), + ("stable_cascade_stage_b_lite", None), + ("stable_cascade_stage_c", None), + ("stable_cascade_stage_c_lite", None), + ("upscale", StableDiffusionUpscalePipeline), + ("v1", None), + ("v2", None), + ("xl_base", None), + ("xl_inpaint", StableDiffusionXLInpaintPipeline), + ("xl_refiner", None), + ] +) + if is_sentencepiece_available(): from .kolors import KolorsImg2ImgPipeline, KolorsPipeline from .pag import KolorsPAGPipeline @@ -263,6 +405,78 @@ def get_model(pipeline_class_name): raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}") +@validate_hf_hub_args +def auto_load_single_checkpoint(pretrained_model_or_path, pipeline_mapping, **kwargs): + r""" + Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors` + format. The pipeline is set in evaluation mode (`model.eval()`) by default. + + Parameters: + pretrained_model_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + - A link to the `.ckpt` file (for example + `"https://huggingface.co//blob/main/.ckpt"`) on the Hub. + - A path to a *file* containing all pipeline weights. + pipeline_mapping (dict): + A dictionary mapping model types to their corresponding pipeline classes. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model with another dtype. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. 
If `True`, the token generated from
+            `diffusers-cli login` (stored in `~/.huggingface`) is used.
+        revision (`str`, *optional*, defaults to `"main"`):
+            The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+            allowed by Git.
+        original_config_file (`str`, *optional*):
+            The path to the original config file that was used to train the model. If not provided, the config file
+            will be inferred from the checkpoint file.
+        config (`str`, *optional*):
+            Can be either:
+                - A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
+                  hosted on the Hub.
+                - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
+                  component configs in Diffusers format.
+        kwargs (remaining dictionary of keyword arguments, *optional*):
+            Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
+            class). The overwritten components are passed directly to the pipelines `__init__` method. See example
+            below for more information.
+    """
+
+    # Load the checkpoint from the provided link or path
+    checkpoint = load_single_file_checkpoint(pretrained_model_or_path)
+
+    # Infer the model type from the loaded checkpoint
+    model_type = infer_diffusers_model_type(checkpoint)
+
+    # Get the corresponding pipeline class from the pipeline mapping
+    pipeline_class = pipeline_mapping.get(model_type, None)
+
+    # For tasks not supported by this pipeline
+    if pipeline_class is None:
+        raise ValueError(
+            f"{model_type} is not supported in this pipeline. "
+            "For `Text2Image`, please use `AutoPipelineForText2Image.from_pretrained`, "
+            "for `Image2Image`, please use `AutoPipelineForImage2Image.from_pretrained`, "
+            "and for `Inpainting`, please use `AutoPipelineForInpainting.from_pretrained`."
+        )
+
+    else:
+        # Instantiate and return the pipeline with the loaded checkpoint and any additional kwargs
+        return pipeline_class.from_single_file(pretrained_model_or_path, checkpoint=checkpoint, **kwargs)
+
+
 class AutoPipelineForText2Image(ConfigMixin):
     r"""
@@ -321,6 +535,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using [`~DiffusionPipeline.save_pretrained`].
+                    - A link to the `.safetensors` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
+                    - A path to a *file* containing all pipeline weights.
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
@@ -402,7 +619,18 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         >>> pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
         >>> image = pipeline(prompt).images[0]
         ```
+        Example_2:
+        ```py
+        >>> from diffusers import AutoPipelineForText2Image
+
+        >>> pipeline = AutoPipelineForText2Image.from_pretrained(
+        ...     "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
+        ... )
+        ```
         """
+
+        # Copy the kwargs to re-use when loading a single-file checkpoint.
+        kwargs_copied = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -419,27 +647,46 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "revision": revision,
         }
-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
-        if "ControlPipeline" in orig_class_name:
-            to_replace = "ControlPipeline"
-        else:
-            to_replace = "Pipeline"
+        # Get the keyword types for the provided model path
+        hf_model_status = get_keyword_types(pretrained_model_or_path)
-        if "controlnet" in kwargs:
-            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
-                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
-            else:
-                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
-        if "enable_pag" in kwargs:
-            enable_pag = kwargs.pop("enable_pag")
-            if enable_pag:
-                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
-
-        text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
+        # Obtain a loading method. One of the following [None, "from_pretrained", "from_single_file"]
+        load_method_name = hf_model_status["loading_method"]
-        kwargs = {**load_config_kwargs, **kwargs}
-        return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        # Check if it is loadable.
+        if load_method_name is not None:
+            if load_method_name == "from_single_file":
+                # For single file checkpoint
+                return auto_load_single_checkpoint(
+                    pretrained_model_or_path=pretrained_model_or_path,
+                    pipeline_mapping=SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING,
+                    **kwargs_copied,
+                )
+            else:
+                config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+                orig_class_name = config["_class_name"]
+                if "ControlPipeline" in orig_class_name:
+                    to_replace = "ControlPipeline"
+                else:
+                    to_replace = "Pipeline"
+
+                if "controlnet" in kwargs:
+                    if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                        orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+                    else:
+                        orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
+                if "enable_pag" in kwargs:
+                    enable_pag = kwargs.pop("enable_pag")
+                    if enable_pag:
+                        orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
+
+                text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
+
+                kwargs = {**load_config_kwargs, **kwargs}
+                return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        else:
+            # Exception handling when loading is not possible
+            raise ValueError(f"Invalid path or URL: {pretrained_model_or_path}")
 
     @classmethod
     def from_pipe(cls, pipeline, **kwargs):
@@ -617,6 +864,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using [`~DiffusionPipeline.save_pretrained`].
+                    - A link to the `.safetensors` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
+                    - A path to a *file* containing all pipeline weights.
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
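With the routing above, `AutoPipelineForText2Image.from_pretrained` first classifies its input with `get_keyword_types`; a single checkpoint file is handed to `auto_load_single_checkpoint`, which infers the model type and dispatches to the mapped pipeline's `from_single_file`. Below is a minimal usage sketch of that path, assuming this branch of the PR is installed; the URL is the one already used in the PR's docstring examples, and the commented values reflect the classification logic added above rather than verified output.

```py
from diffusers import AutoPipelineForText2Image
from diffusers.loaders.single_file_utils import get_keyword_types

url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"

# A `.safetensors` file URL on the Hub should be classified as a single-file checkpoint,
# so the loading method comes back as "from_single_file".
status = get_keyword_types(url)
print(status["checkpoint_format"])  # expected: "single_file"
print(status["loading_method"])     # expected: "from_single_file"

# from_pretrained therefore skips `load_config` and routes the call through
# SINGLE_FILE_CHECKPOINT_TEXT2IMAGE_PIPELINE_MAPPING to the matching pipeline's from_single_file.
pipeline = AutoPipelineForText2Image.from_pretrained(url)
image = pipeline("a photo of an astronaut riding a horse").images[0]
```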
@@ -698,7 +948,19 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         >>> pipeline = AutoPipelineForImage2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
         >>> image = pipeline(prompt, image).images[0]
         ```
+        Example_2:
+        ```py
+        >>> from diffusers import AutoPipelineForImage2Image
+
+        >>> pipeline = AutoPipelineForImage2Image.from_pretrained(
+        ...     "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
+        ... )
+        >>> image = pipeline(prompt, image).images[0]
+        ```
         """
+
+        # Copy the kwargs to re-use when loading a single-file checkpoint.
+        kwargs_copied = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -715,37 +977,56 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "revision": revision,
         }
-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
-
-        # the `orig_class_name` can be:
-        # `- *Pipeline` (for regular text-to-image checkpoint)
-        # - `*ControlPipeline` (for Flux tools specific checkpoint)
-        # `- *Img2ImgPipeline` (for refiner checkpoint)
-        if "Img2Img" in orig_class_name:
-            to_replace = "Img2ImgPipeline"
-        elif "ControlPipeline" in orig_class_name:
-            to_replace = "ControlPipeline"
-        else:
-            to_replace = "Pipeline"
+        # Get the keyword types for the provided model path
+        hf_model_status = get_keyword_types(pretrained_model_or_path)
-        if "controlnet" in kwargs:
-            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
-                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
-            else:
-                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
-        if "enable_pag" in kwargs:
-            enable_pag = kwargs.pop("enable_pag")
-            if enable_pag:
-                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+        # Obtain a loading method. One of the following [None, "from_pretrained", "from_single_file"]
+        load_method_name = hf_model_status["loading_method"]
-        if to_replace == "ControlPipeline":
-            orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
-
-        image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
-
-        kwargs = {**load_config_kwargs, **kwargs}
-        return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        # Check if it is loadable.
+        if load_method_name is not None:
+            if load_method_name == "from_single_file":
+                # For single file checkpoint
+                return auto_load_single_checkpoint(
+                    pretrained_model_or_path=pretrained_model_or_path,
+                    pipeline_mapping=SINGLE_FILE_CHECKPOINT_IMAGE2IMAGE_PIPELINE_MAPPING,
+                    **kwargs_copied,
+                )
+            else:
+                config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+                orig_class_name = config["_class_name"]
+
+                # the `orig_class_name` can be:
+                # `- *Pipeline` (for regular text-to-image checkpoint)
+                # - `*ControlPipeline` (for Flux tools specific checkpoint)
+                # `- *Img2ImgPipeline` (for refiner checkpoint)
+                if "Img2Img" in orig_class_name:
+                    to_replace = "Img2ImgPipeline"
+                elif "ControlPipeline" in orig_class_name:
+                    to_replace = "ControlPipeline"
+                else:
+                    to_replace = "Pipeline"
+
+                if "controlnet" in kwargs:
+                    if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                        orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+                    else:
+                        orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+                if "enable_pag" in kwargs:
+                    enable_pag = kwargs.pop("enable_pag")
+                    if enable_pag:
+                        orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+
+                if to_replace == "ControlPipeline":
+                    orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
+                image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
+
+                kwargs = {**load_config_kwargs, **kwargs}
+                return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        else:
+            # Exception handling when loading is not possible
+            raise ValueError(f"Invalid path or URL: {pretrained_model_or_path}")
 
     @classmethod
     def from_pipe(cls, pipeline, **kwargs):
@@ -928,6 +1209,9 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using [`~DiffusionPipeline.save_pretrained`].
+                    - A link to the `.safetensors` file (for example
+                      `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
+                    - A path to a *file* containing all pipeline weights.
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
@@ -1010,6 +1294,10 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
         ```
         """
+
+        # Copy the kwargs to re-use when loading a single-file checkpoint.
+        kwargs_copied = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -1025,36 +1313,54 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
             "local_files_only": local_files_only,
             "revision": revision,
         }
-
-        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
-        orig_class_name = config["_class_name"]
-
-        # The `orig_class_name`` can be:
-        # `- *InpaintPipeline` (for inpaint-specific checkpoint)
-        # - `*ControlPipeline` (for Flux tools specific checkpoint)
-        # - or *Pipeline (for regular text-to-image checkpoint)
-        if "Inpaint" in orig_class_name:
-            to_replace = "InpaintPipeline"
-        elif "ControlPipeline" in orig_class_name:
-            to_replace = "ControlPipeline"
-        else:
-            to_replace = "Pipeline"
-
-        if "controlnet" in kwargs:
-            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
-                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+        # Get the keyword types for the provided model path
+        hf_model_status = get_keyword_types(pretrained_model_or_path)
+
+        # Obtain a loading method. One of the following [None, "from_pretrained", "from_single_file"]
+        load_method_name = hf_model_status["loading_method"]
+
+        # Check if it is loadable.
+        if load_method_name is not None:
+            if load_method_name == "from_single_file":
+                # For single file checkpoint
+                return auto_load_single_checkpoint(
+                    pretrained_model_or_path=pretrained_model_or_path,
+                    pipeline_mapping=SINGLE_FILE_CHECKPOINT_INPAINT_PIPELINE_MAPPING,
+                    **kwargs_copied,
+                )
             else:
-                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
-        if "enable_pag" in kwargs:
-            enable_pag = kwargs.pop("enable_pag")
-            if enable_pag:
-                orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
-        if to_replace == "ControlPipeline":
-            orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
-        inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
-
-        kwargs = {**load_config_kwargs, **kwargs}
-        return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+                config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+                orig_class_name = config["_class_name"]
+
+                # The `orig_class_name`` can be:
+                # `- *InpaintPipeline` (for inpaint-specific checkpoint)
+                # - `*ControlPipeline` (for Flux tools specific checkpoint)
+                # - or *Pipeline (for regular text-to-image checkpoint)
+                if "Inpaint" in orig_class_name:
+                    to_replace = "InpaintPipeline"
+                elif "ControlPipeline" in orig_class_name:
+                    to_replace = "ControlPipeline"
+                else:
+                    to_replace = "Pipeline"
+
+                if "controlnet" in kwargs:
+                    if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                        orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+                    else:
+                        orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
+                if "enable_pag" in kwargs:
+                    enable_pag = kwargs.pop("enable_pag")
+                    if enable_pag:
+                        orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+                if to_replace == "ControlPipeline":
+                    orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
+                inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
+
+                kwargs = {**load_config_kwargs, **kwargs}
+                return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+        else:
+            # Exception handling when loading is not possible
+            raise ValueError(f"Invalid path or URL: {pretrained_model_or_path}")
URL: {pretrained_model_or_path}") @classmethod def from_pipe(cls, pipeline, **kwargs): diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 1b306b1805d8..ac70f6ba1ae5 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,8 @@ from .. import __version__ from ..configuration_utils import ConfigMixin +from ..loaders.single_file import FromSingleFileMixin +from ..loaders.single_file_utils import get_keyword_types, load_single_file_checkpoint from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin @@ -136,7 +138,7 @@ class AudioPipelineOutput(BaseOutput): audios: np.ndarray -class DiffusionPipeline(ConfigMixin, PushToHubMixin): +class DiffusionPipeline(ConfigMixin, PushToHubMixin, FromSingleFileMixin): r""" Base class for all pipelines. @@ -674,6 +676,27 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) >>> pipeline.scheduler = scheduler ``` + Examples_2: + + ```py + >>> from diffusers import StableDiffusionPipeline + + >>> # Download pipeline from huggingface.co and cache. + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" + ... ) + + >>> # Download pipeline from local file + >>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt + >>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt") + + >>> # Enable float16 and move to GPU + >>> pipeline = StableDiffusionPipeline.from_single_file( + ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt", + ... torch_dtype=torch.float16, + ... ) + >>> pipeline.to("cuda") + ``` """ # Copy the kwargs to re-use during loading connected pipeline. kwargs_copied = kwargs.copy() @@ -758,6 +781,37 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if load_connected_pipeline: raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.") + # Retrieve information about the path or repo ID + hf_model_status = get_keyword_types(pretrained_model_name_or_path) + + # Obtain a loading method. One of the following [None, "from_pretrained", "from_single_file"] + load_method_name = hf_model_status["loading_method"] + + if load_method_name is None: + # Raise an error if the path is invalid + raise ValueError(f"Invalid path or URL: {pretrained_model_name_or_path}") + + # Load the pipeline from a single file + elif load_method_name == "from_single_file": + # The arguments for the __init__ method of `DiffusionPipeline` are keyword arguments, so they cannot be loaded from `from_single_file` + if cls.__name__ == "DiffusionPipeline": + # import it here to avoid circular import + from .stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt + + checkpoint = load_single_file_checkpoint( + pretrained_model_name_or_path, + force_download=force_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + return download_from_original_stable_diffusion_ckpt(checkpoint, **kwargs) + else: + return cls.from_single_file(pretrained_model_name_or_path, **kwargs) + # 1. 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
@@ -1486,8 +1540,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
 
         if load_components_from_hub and not trust_remote_code:
             raise ValueError(
-                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
-                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
+                f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
+                f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
                 f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
             )
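The `pipeline_utils.py` changes give `DiffusionPipeline.from_pretrained` the same classification step: a single-file input on a concrete subclass is forwarded to `cls.from_single_file`, while the base `DiffusionPipeline` class falls back to the legacy `download_from_original_stable_diffusion_ckpt` conversion path. Below is a short sketch of both branches, assuming this branch of the PR is installed; the local checkpoint path is the illustrative one from the docstring example above.

```py
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline

# On a concrete subclass, a local .ckpt/.safetensors path is now accepted by from_pretrained
# and is routed to StableDiffusionPipeline.from_single_file under the hood.
pipeline = StableDiffusionPipeline.from_pretrained(
    "./v1-5-pruned-emaonly.ckpt",
    torch_dtype=torch.float16,
)

# On the base class, the checkpoint is first loaded with load_single_file_checkpoint and then
# converted through the legacy download_from_original_stable_diffusion_ckpt helper.
pipeline = DiffusionPipeline.from_pretrained("./v1-5-pruned-emaonly.ckpt")
```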