diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index edac6bfd9e4e..59a473e32ae1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -224,7 +224,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, scheduler: Union[ DDIMScheduler, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py index 1a75d658b3ad..fd4d5346f7c1 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -246,7 +246,7 @@ def __init__( vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, + unet: Union[UNet2DConditionModel, UNetMotionModel], motion_adapter: MotionAdapter, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: Union[ diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index f01c8cc4674d..5ee712b5f116 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -232,8 +232,8 @@ def __init__( Tuple[HunyuanDiT2DControlNetModel], HunyuanDiT2DMultiControlNetModel, ], - text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, requires_safety_checker: bool = True, ): super().__init__() diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 7f85fcc1d90d..7f7acd882b59 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipImageProcessor, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline( Provides additional conditioning to the `unet` during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. 
""" @@ -202,8 +202,8 @@ def __init__( controlnet: Union[ SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel ], - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() if isinstance(controlnet, (list, tuple)): diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 35e47f4d650e..cb35f67fa112 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipImageProcessor, + SiglipModel, T5EncoderModel, T5TokenizerFast, ) @@ -223,8 +223,8 @@ def __init__( controlnet: Union[ SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel ], - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: SiglipModel = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index ed342f66804a..34b2a3945572 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -17,6 +17,8 @@ import torch +from ...models import UNet1DModel +from ...schedulers import SchedulerMixin from ...utils import is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline @@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 1b424f5742f2..1fd8ce4e6570 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -16,6 +16,7 @@ import torch +from ...models import UNet2DModel from ...schedulers import DDIMScheduler from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor @@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler): super().__init__() # make sure scheduler can always be converted to DDIM diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index e58a53b5b7e8..1c5ac4baeae0 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -17,6 +17,8 @@ import torch +from ...models import UNet2DModel +from ...schedulers import DDPMScheduler from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, 
ImagePipelineOutput @@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline): model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py index 101d315dfe59..843528a532f1 100644 --- a/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py @@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline): scheduler: RePaintScheduler model_cpu_offload_seq = "unet" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index 6a5cf298d2d4..febf2b0392cc 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -207,8 +207,8 @@ def __init__( safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, - text_encoder_2=T5EncoderModel, - tokenizer_2=MT5Tokenizer, + text_encoder_2: Optional[T5EncoderModel] = None, + tokenizer_2: Optional[MT5Tokenizer] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 4f6793e17b37..b50079532f94 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoModel, AutoTokenizer +from transformers import GemmaPreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import VaeImageProcessor @@ -144,13 +144,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`AutoModel`]): - Frozen text-encoder. Lumina-T2I uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. - tokenizer (`AutoModel`): - Tokenizer of class - [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + text_encoder ([`GemmaPreTrainedModel`]): + Frozen Gemma text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents. 
scheduler ([`SchedulerMixin`]): @@ -185,8 +182,8 @@ def __init__( transformer: LuminaNextDiT2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: AutoModel, - tokenizer: AutoTokenizer, + text_encoder: GemmaPreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], ): super().__init__() diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 40e42bbe6ba6..514192cb70c7 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -17,7 +17,7 @@ import numpy as np import torch -from transformers import AutoModel, AutoTokenizer +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...image_processor import VaeImageProcessor from ...loaders import Lumina2LoraLoaderMixin @@ -143,13 +143,10 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline, Lumina2LoraLoaderMixin): Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`AutoModel`]): - Frozen text-encoder. Lumina-T2I uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. - tokenizer (`AutoModel`): - Tokenizer of class - [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + text_encoder ([`Gemma2PreTrainedModel`]): + Frozen Gemma2 text-encoder. + tokenizer (`GemmaTokenizer` or `GemmaTokenizerFast`): + Gemma tokenizer. transformer ([`Transformer2DModel`]): A text conditioned `Transformer2DModel` to denoise the encoded image latents. 
scheduler ([`SchedulerMixin`]): @@ -165,8 +162,8 @@ def __init__( transformer: Lumina2Transformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, - text_encoder: AutoModel, - tokenizer: AutoTokenizer, + text_encoder: Gemma2PreTrainedModel, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], ): super().__init__() diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index d0bbb46b09e7..030ab6db7391 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -20,7 +20,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin): def __init__( self, - tokenizer: AutoTokenizer, - text_encoder: AutoModelForCausalLM, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: FlowMatchEulerDiscreteScheduler, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9a9afa198b4c..0e2cbb32d3c1 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -17,7 +17,7 @@ import re import warnings from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin import requests import torch @@ -1059,3 +1059,76 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict): break if has_transformers_component and not is_transformers_version(">", "4.47.1"): raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") + + +def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool: + """ + Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of + the correct type as well. + """ + if not isinstance(class_or_tuple, tuple): + class_or_tuple = (class_or_tuple,) + + # Unpack unions + unpacked_class_or_tuple = [] + for t in class_or_tuple: + if get_origin(t) is Union: + unpacked_class_or_tuple.extend(get_args(t)) + else: + unpacked_class_or_tuple.append(t) + class_or_tuple = tuple(unpacked_class_or_tuple) + + if Any in class_or_tuple: + return True + + obj_type = type(obj) + # Keep only the candidate types that obj is an instance of (using the generic origin for collections) + class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)} + + # Bare types (e.g. int, ControlNet, ...) and untyped collections + # (e.g. List, but not List[int]) produce no type args, so any matching obj is valid + elem_class_or_tuple = {get_args(t) for t in class_or_tuple} + if () in elem_class_or_tuple: + return True + # Typed lists or sets + elif obj_type in (list, set): + return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple) + # Typed tuples + elif obj_type is tuple: + return any( + # Tuples with any length and single type (e.g. Tuple[int, ...]) + (len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj)) + or + # Tuples with fixed length and any types (e.g. 
Tuple[int, str]) + (len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t))) + for t in elem_class_or_tuple + ) + # Typed dicts + elif obj_type is dict: + return any( + all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items()) + for kt, vt in elem_class_or_tuple + ) + + else: + return False + + +def _get_detailed_type(obj: Any) -> Type: + """ + Gets a detailed type for an object, including nested types for collections. + """ + obj_type = type(obj) + + if obj_type in (list, set): + obj_origin_type = List if obj_type is list else Set + elems_type = Union[tuple({_get_detailed_type(x) for x in obj})] + return obj_origin_type[elems_type] + elif obj_type is tuple: + return Tuple[tuple(_get_detailed_type(x) for x in obj)] + elif obj_type is dict: + keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})] + values_type = Union[tuple({_get_detailed_type(v) for v in obj.values()})] + return Dict[keys_type, values_type] + else: + return obj_type diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 26bd938b2734..90a05e97f614 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import fnmatch import importlib import inspect @@ -79,10 +78,12 @@ _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, + _get_detailed_type, _get_final_device_map, _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, + _is_valid_type, _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, @@ -876,26 +877,6 @@ def load_module(name, value): init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)} - for key in init_dict.keys(): - if key not in passed_class_obj: - continue - if "scheduler" in key: - continue - - class_obj = passed_class_obj[key] - _expected_class_types = [] - for expected_type in expected_types[key]: - if isinstance(expected_type, enum.EnumMeta): - _expected_class_types.extend(expected_type.__members__.keys()) - else: - _expected_class_types.append(expected_type.__name__) - - _is_valid_type = class_obj.__class__.__name__ in _expected_class_types - if not _is_valid_type: - logger.warning( - f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}." - ) - # Special case: safety_checker must be loaded separately when using `from_flax` if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj: raise NotImplementedError( @@ -1015,10 +996,26 @@ def load_module(name, value): f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed." ) - # 10. Instantiate the pipeline + # 10. 
Type checking init arguments + for kw, arg in init_kwargs.items(): + # Too complex to validate with type annotations alone + if "scheduler" in kw: + continue + # Many tokenizer annotations don't include their "Fast" variants, so skip this check + # e.g. T5Tokenizer but not T5TokenizerFast + elif "tokenizer" in kw: + continue + elif ( + arg is not None # Skip if None + and expected_types[kw] != (inspect.Signature.empty,) # Skip if no type annotations + and not _is_valid_type(arg, expected_types[kw]) # Check type + ): + logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {_get_detailed_type(arg)}.") + + # 11. Instantiate the pipeline model = pipeline_class(**init_kwargs) - # 11. Save where the model was instantiated from + # 12. Save where the model was instantiated from model.register_to_config(_name_or_path=pretrained_model_name_or_path) if device_map is not None: setattr(model, "hf_device_map", final_device_map) diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 11c63be52a87..460e7e2a237a 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -20,7 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PixArtImageProcessor @@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): def __init__( self, - tokenizer: AutoTokenizer, - text_encoder: AutoModelForCausalLM, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, vae: AutoencoderDC, transformer: SanaTransformer2DModel, scheduler: DPMSolverMultistepScheduler, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index e3b9ec44005a..38f1c4314e4f 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -15,7 +15,7 @@ from typing import Callable, Dict, List, Optional, Union import torch -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModelWithProjection, CLIPTokenizer from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler @@ -65,7 +65,7 @@ class StableCascadeDecoderPipeline(DiffusionPipeline): Args: tokenizer (`CLIPTokenizer`): The CLIP tokenizer. - text_encoder (`CLIPTextModel`): + text_encoder (`CLIPTextModelWithProjection`): The CLIP text encoder. decoder ([`StableCascadeUNet`]): The Stable Cascade decoder unet. 
@@ -93,7 +93,7 @@ def __init__( self, decoder: StableCascadeUNet, tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, + text_encoder: CLIPTextModelWithProjection, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, latent_dim_scale: float = 10.67, diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py index 6724b60cc424..28a74ab83733 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py @@ -15,7 +15,7 @@ import PIL import torch -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler @@ -52,7 +52,7 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): Args: tokenizer (`CLIPTokenizer`): The decoder tokenizer to be used for text inputs. - text_encoder (`CLIPTextModel`): + text_encoder (`CLIPTextModelWithProjection`): The decoder text encoder to be used for text inputs. decoder (`StableCascadeUNet`): The decoder model to be used for decoder image generation pipeline. @@ -60,14 +60,18 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): The scheduler to be used for decoder image generation pipeline. vqgan (`PaellaVQModel`): The VQGAN model to be used for decoder image generation pipeline. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `image_encoder`. - image_encoder ([`CLIPVisionModelWithProjection`]): - Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). prior_prior (`StableCascadeUNet`): The prior model to be used for prior pipeline. + prior_text_encoder (`CLIPTextModelWithProjection`): + The prior text encoder to be used for text inputs. + prior_tokenizer (`CLIPTokenizer`): + The prior tokenizer to be used for text inputs. prior_scheduler (`DDPMWuerstchenScheduler`): The scheduler to be used for prior pipeline. + prior_feature_extractor ([`~transformers.CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `image_encoder`. + prior_image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). 
""" _load_connected_pipes = True @@ -76,12 +80,12 @@ class StableCascadeCombinedPipeline(DiffusionPipeline): def __init__( self, tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModel, + text_encoder: CLIPTextModelWithProjection, decoder: StableCascadeUNet, scheduler: DDPMWuerstchenScheduler, vqgan: PaellaVQModel, prior_prior: StableCascadeUNet, - prior_text_encoder: CLIPTextModel, + prior_text_encoder: CLIPTextModelWithProjection, prior_tokenizer: CLIPTokenizer, prior_scheduler: DDPMWuerstchenScheduler, prior_feature_extractor: Optional[CLIPImageProcessor] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 07d82251d4ba..be01e0acbf18 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -141,7 +141,7 @@ def __init__( image_noising_scheduler: KarrasDiffusionSchedulers, # regular denoising components tokenizer: CLIPTokenizer, - text_encoder: CLIPTextModelWithProjection, + text_encoder: CLIPTextModel, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, # vae diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 588abc8ef2dc..4618d384cbd7 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipImageProcessor, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. """ @@ -197,8 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: SiglipVisionModel = None, + feature_extractor: SiglipImageProcessor = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 3d3c8b6781fc..19bdc9792e23 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -18,10 +18,10 @@ import PIL.Image import torch from transformers import ( - BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipImageProcessor, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). 
+ image_encoder (`SiglipVisionModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`SiglipImageProcessor`, *optional*): + Image processor for IP Adapter. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" @@ -214,8 +218,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 71103187f47b..c69fb90a4c5e 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -17,10 +17,10 @@ import torch from transformers import ( - BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, - PreTrainedModel, + SiglipImageProcessor, + SiglipVisionModel, T5EncoderModel, T5TokenizerFast, ) @@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - image_encoder (`PreTrainedModel`, *optional*): + image_encoder (`SiglipVisionModel`, *optional*): Pre-trained Vision Model for IP Adapter. - feature_extractor (`BaseImageProcessor`, *optional*): + feature_extractor (`SiglipImageProcessor`, *optional*): Image processor for IP Adapter. 
""" @@ -217,8 +217,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, - image_encoder: PreTrainedModel = None, - feature_extractor: BaseImageProcessor = None, + image_encoder: Optional[SiglipVisionModel] = None, + feature_extractor: Optional[SiglipImageProcessor] = None, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py index 24e11bff3052..1f29f577f8e0 100755 --- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -19,15 +19,31 @@ import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPTokenizerFast, +) from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import ( + StableDiffusionLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder -from ...schedulers import LMSDiscreteScheduler -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline( def __init__( self, - vae, - text_encoder, - tokenizer, - unet, - scheduler, - safety_checker, - feature_extractor, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast], + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() diff --git a/tests/fixtures/custom_pipeline/pipeline.py b/tests/fixtures/custom_pipeline/pipeline.py index 601f51b1263e..e197cb6859fa 100644 --- a/tests/fixtures/custom_pipeline/pipeline.py +++ b/tests/fixtures/custom_pipeline/pipeline.py @@ -18,7 +18,7 @@ import torch -from diffusers import DiffusionPipeline, ImagePipelineOutput +from diffusers import DiffusionPipeline, ImagePipelineOutput, SchedulerMixin, UNet2DModel class CustomLocalPipeline(DiffusionPipeline): @@ -33,7 +33,7 @@ class CustomLocalPipeline(DiffusionPipeline): [`DDPMScheduler`], or [`DDIMScheduler`]. 
""" - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/tests/fixtures/custom_pipeline/what_ever.py b/tests/fixtures/custom_pipeline/what_ever.py index 8ceeb4211e37..bbe7f4f16bd8 100644 --- a/tests/fixtures/custom_pipeline/what_ever.py +++ b/tests/fixtures/custom_pipeline/what_ever.py @@ -18,6 +18,7 @@ import torch +from diffusers import SchedulerMixin, UNet2DModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -33,7 +34,7 @@ class CustomLocalPipeline(DiffusionPipeline): [`DDPMScheduler`], or [`DDIMScheduler`]. """ - def __init__(self, unet, scheduler): + def __init__(self, unet: UNet2DModel, scheduler: SchedulerMixin): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index 3e783b80e7e4..aa0571559b45 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -91,10 +91,10 @@ def get_dummy_components(self): text_encoder = Gemma2Model(config) components = { - "transformer": transformer.eval(), + "transformer": transformer, "vae": vae.eval(), "scheduler": scheduler, - "text_encoder": text_encoder.eval(), + "text_encoder": text_encoder, "tokenizer": tokenizer, } return components