Skip to content

Commit 5ca27aa

Browse files
committed
Type checks subclasses and fixed type warnings
1 parent b1f26c5 commit 5ca27aa

File tree

17 files changed

+48
-38
lines changed

17 files changed

+48
-38
lines changed

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def __init__(
224224
vae: AutoencoderKL,
225225
text_encoder: CLIPTextModel,
226226
tokenizer: CLIPTokenizer,
227-
unet: UNet2DConditionModel,
227+
unet: Union[UNet2DConditionModel, UNetMotionModel],
228228
motion_adapter: MotionAdapter,
229229
scheduler: Union[
230230
DDIMScheduler,

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def __init__(
246246
vae: AutoencoderKL,
247247
text_encoder: CLIPTextModel,
248248
tokenizer: CLIPTokenizer,
249-
unet: UNet2DConditionModel,
249+
unet: Union[UNet2DConditionModel, UNetMotionModel],
250250
motion_adapter: MotionAdapter,
251251
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
252252
scheduler: Union[

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def __init__(
232232
Tuple[HunyuanDiT2DControlNetModel],
233233
HunyuanDiT2DMultiControlNetModel,
234234
],
235-
text_encoder_2=T5EncoderModel,
236-
tokenizer_2=MT5Tokenizer,
235+
text_encoder_2: Optional[T5EncoderModel] = None,
236+
tokenizer_2: Optional[MT5Tokenizer] = None,
237237
requires_safety_checker: bool = True,
238238
):
239239
super().__init__()

src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet1DModel
21+
from ...schedulers import SchedulerMixin
2022
from ...utils import is_torch_xla_available, logging
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
@@ -49,7 +51,7 @@ class DanceDiffusionPipeline(DiffusionPipeline):
4951

5052
model_cpu_offload_seq = "unet"
5153

52-
def __init__(self, unet, scheduler):
54+
def __init__(self, unet: UNet1DModel, scheduler: SchedulerMixin):
5355
super().__init__()
5456
self.register_modules(unet=unet, scheduler=scheduler)
5557

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import torch
1818

19+
from ...models import UNet2DModel
1920
from ...schedulers import DDIMScheduler
2021
from ...utils import is_torch_xla_available
2122
from ...utils.torch_utils import randn_tensor
@@ -47,7 +48,7 @@ class DDIMPipeline(DiffusionPipeline):
4748

4849
model_cpu_offload_seq = "unet"
4950

50-
def __init__(self, unet, scheduler):
51+
def __init__(self, unet: UNet2DModel, scheduler: DDIMScheduler):
5152
super().__init__()
5253

5354
# make sure scheduler can always be converted to DDIM

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import torch
1919

20+
from ...models import UNet2DModel
21+
from ...schedulers import DDPMScheduler
2022
from ...utils import is_torch_xla_available
2123
from ...utils.torch_utils import randn_tensor
2224
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -47,7 +49,7 @@ class DDPMPipeline(DiffusionPipeline):
4749

4850
model_cpu_offload_seq = "unet"
4951

50-
def __init__(self, unet, scheduler):
52+
def __init__(self, unet: UNet2DModel, scheduler: DDPMScheduler):
5153
super().__init__()
5254
self.register_modules(unet=unet, scheduler=scheduler)
5355

src/diffusers/pipelines/deprecated/repaint/pipeline_repaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class RePaintPipeline(DiffusionPipeline):
9191
scheduler: RePaintScheduler
9292
model_cpu_offload_seq = "unet"
9393

94-
def __init__(self, unet, scheduler):
94+
def __init__(self, unet: UNet2DModel, scheduler: RePaintScheduler):
9595
super().__init__()
9696
self.register_modules(unet=unet, scheduler=scheduler)
9797

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def __init__(
207207
safety_checker: StableDiffusionSafetyChecker,
208208
feature_extractor: CLIPImageProcessor,
209209
requires_safety_checker: bool = True,
210-
text_encoder_2=T5EncoderModel,
211-
tokenizer_2=MT5Tokenizer,
210+
text_encoder_2: Optional[T5EncoderModel] = None,
211+
tokenizer_2: Optional[MT5Tokenizer] = None,
212212
):
213213
super().__init__()
214214

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModel, AutoTokenizer
23+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
2424

2525
from ...image_processor import VaeImageProcessor
2626
from ...models import AutoencoderKL
@@ -143,13 +143,13 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
143143
Args:
144144
vae ([`AutoencoderKL`]):
145145
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
146-
text_encoder ([`AutoModel`]):
146+
text_encoder ([`PreTrainedModel`]):
147147
Frozen text-encoder. Lumina-T2I uses
148148
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
149149
[t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
150-
tokenizer (`AutoModel`):
150+
tokenizer (`AutoTokenizer`):
151151
Tokenizer of class
152-
[AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
152+
[AutoTokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
153153
transformer ([`Transformer2DModel`]):
154154
A text conditioned `Transformer2DModel` to denoise the encoded image latents.
155155
scheduler ([`SchedulerMixin`]):
@@ -180,8 +180,8 @@ def __init__(
180180
transformer: LuminaNextDiT2DModel,
181181
scheduler: FlowMatchEulerDiscreteScheduler,
182182
vae: AutoencoderKL,
183-
text_encoder: AutoModel,
184-
tokenizer: AutoTokenizer,
183+
text_encoder: PreTrainedModel,
184+
tokenizer: PreTrainedTokenizerBase,
185185
):
186186
super().__init__()
187187

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
23-
from transformers import AutoModelForCausalLM, AutoTokenizer
23+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
2424

2525
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2626
from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
160160

161161
def __init__(
162162
self,
163-
tokenizer: AutoTokenizer,
164-
text_encoder: AutoModelForCausalLM,
163+
tokenizer: PreTrainedTokenizerBase,
164+
text_encoder: PreTrainedModel,
165165
vae: AutoencoderDC,
166166
transformer: SanaTransformer2DModel,
167167
scheduler: FlowMatchEulerDiscreteScheduler,

0 commit comments

Comments
 (0)