Skip to content

Commit 9c7e205

Browse files
guiyrthlky and hlky authored
Comprehensive type checking for from_pretrained kwargs (huggingface#10758)
* More robust from_pretrained init_kwargs type checking
* Corrected for Python 3.10
* Type checks subclasses and fixed type warnings
* More type corrections and skip tokenizer type checking
* make style && make quality
* Updated docs and types for Lumina pipelines
* Fixed check for empty signature
* changed location of helper functions
* make style

---------

Co-authored-by: hlky <[email protected]>
1 parent 64dec70 commit 9c7e205

26 files changed

+208
-114
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
vae: AutoencoderKL,
225225
text_encoder: CLIPTextModel,
226226
tokenizer: CLIPTokenizer,
227-
unet: UNet2DConditionModel,
227+
unet: Union[UNet2DConditionModel, UNetMotionModel],
228228
motion_adapter: MotionAdapter,
229229
scheduler: Union[
230230
DDIMScheduler,

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
vae: AutoencoderKL,
247247
text_encoder: CLIPTextModel,
248248
tokenizer: CLIPTokenizer,
249-
unet: UNet2DConditionModel,
249+
unet: Union[UNet2DConditionModel, UNetMotionModel],
250250
motion_adapter: MotionAdapter,
251251
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
252252
scheduler: Union[

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def __init__(
232232
Tuple[HunyuanDiT2DControlNetModel],
233233
HunyuanDiT2DMultiControlNetModel,
234234
],
235-
text_encoder_2=T5EncoderModel,
236-
tokenizer_2=MT5Tokenizer,
235+
text_encoder_2: Optional[T5EncoderModel] = None,
236+
tokenizer_2: Optional[MT5Tokenizer] = None,
237237
requires_safety_checker: bool = True,
238238
):
239239
super().__init__()

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import torch
1919
from transformers import (
20-
BaseImageProcessor,
2120
CLIPTextModelWithProjection,
2221
CLIPTokenizer,
23-
PreTrainedModel,
22+
SiglipImageProcessor,
23+
SiglipVisionModel,
2424
T5EncoderModel,
2525
T5TokenizerFast,
2626
)
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
178178
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
179179
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
180180
additional conditioning.
181-
image_encoder (`PreTrainedModel`, *optional*):
181+
image_encoder (`SiglipVisionModel`, *optional*):
182182
Pre-trained Vision Model for IP Adapter.
183-
feature_extractor (`BaseImageProcessor`, *optional*):
183+
feature_extractor (`SiglipImageProcessor`, *optional*):
184184
Image processor for IP Adapter.
185185
"""
186186

@@ -202,8 +202,8 @@ def __init__(
202202
controlnet: Union[
203203
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
204204
],
205-
image_encoder: PreTrainedModel = None,
206-
feature_extractor: BaseImageProcessor = None,
205+
image_encoder: Optional[SiglipVisionModel] = None,
206+
feature_extractor: Optional[SiglipImageProcessor] = None,
207207
):
208208
super().__init__()
209209
if isinstance(controlnet, (list, tuple)):

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
import torch
1919
from transformers import (
20-
BaseImageProcessor,
2120
CLIPTextModelWithProjection,
2221
CLIPTokenizer,
23-
PreTrainedModel,
22+
SiglipImageProcessor,
23+
SiglipModel,
2424
T5EncoderModel,
2525
T5TokenizerFast,
2626
)
@@ -223,8 +223,8 @@ def __init__(
223223
controlnet: Union[
224224
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
225225
],
226-
image_encoder: PreTrainedModel = None,
227-
feature_extractor: BaseImageProcessor = None,
226+
image_encoder: SiglipModel = None,
227+
feature_extractor: Optional[SiglipImageProcessor] = None,
228228
):
229229
super().__init__()
230230

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet1DModel
21+
from ...schedulers import SchedulerMixin
2022
from ...utils import is_torch_xla_available, logging
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
4951

5052
model_cpu_offload_seq = "unet"
5153

52-
def __init__(self, unet, scheduler):
54+
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
5355
super().__init__()
5456
self.register_modules(unet=unet, scheduler=scheduler)
5557

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818

19+
from ...models import UNet2DModel
1920
from ...schedulers import DDIMScheduler
2021
from ...utils import is_torch_xla_available
2122
from ...utils.torch_utils import randn_tensor
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
4748

4849
model_cpu_offload_seq = "unet"
4950

50-
def __init__(self, unet, scheduler):
51+
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
5152
super().__init__()
5253

5354
# make sure scheduler can always be converted to DDIM

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet2DModel
21+
from ...schedulers import DDPMScheduler
2022
from ...utils import is_torch_xla_available
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
4749

4850
model_cpu_offload_seq = "unet"
4951

50-
def __init__(self, unet, scheduler):
52+
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
5153
super().__init__()
5254
self.register_modules(unet=unet, scheduler=scheduler)
5355

src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
9191
scheduler: RePaintScheduler
9292
model_cpu_offload_seq = "unet"
9393

94-
def __init__(self, unet, scheduler):
94+
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
9595
super().__init__()
9696
self.register_modules(unet=unet, scheduler=scheduler)
9797

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207
safety_checker: StableDiffusionSafetyChecker,
208208
feature_extractor: CLIPImageProcessor,
209209
requires_safety_checker: bool = True,
210-
text_encoder_2=T5EncoderModel,
211-
tokenizer_2=MT5Tokenizer,
210+
text_encoder_2: Optional[T5EncoderModel] = None,
211+
tokenizer_2: Optional[MT5Tokenizer] = None,
212212
):
213213
super().__init__()
214214

0 commit comments

Comments (0)