
Commit 12eb38f

More type corrections and skip tokenizer type checking
1 parent: 70efb66

10 files changed: +81 -53 lines changed

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 6 additions & 6 deletions
@@ -17,10 +17,10 @@
 
 import torch
 from transformers import (
-    BaseImageProcessor,
+    SiglipImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -178,9 +178,9 @@ class StableDiffusion3ControlNetPipeline(
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """
 
@@ -202,8 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):

src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Lines changed: 4 additions & 4 deletions
@@ -17,10 +17,10 @@
 
 import torch
 from transformers import (
-    BaseImageProcessor,
+    SiglipImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -223,8 +223,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: SiglipModel = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()

src/diffusers/pipelines/lumina2/pipeline_lumina2.py

Lines changed: 5 additions & 5 deletions
@@ -17,7 +17,7 @@
 
 import numpy as np
 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
@@ -150,11 +150,11 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline):
     Args:
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`AutoModel`]):
+        text_encoder ([`PreTrainedModel`]):
             Frozen text-encoder. Lumina-T2I uses
             [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
             [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
-        tokenizer (`AutoModel`):
+        tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer of class
             [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
         transformer ([`Transformer2DModel`]):
@@ -172,8 +172,8 @@ def __init__(
         transformer: Lumina2Transformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,
         vae: AutoencoderKL,
-        text_encoder: AutoModel,
-        tokenizer: AutoTokenizer,
+        text_encoder: PreTrainedModel,
+        tokenizer: PreTrainedTokenizerBase,
     ):
         super().__init__()
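
Note: a minimal sketch, not part of the commit, of why the broader base-class annotations still validate. Every concrete transformers text encoder and tokenizer subclasses these bases, so an isinstance-style check against them accepts the T5 variants the docstring names.

from transformers import PreTrainedModel, PreTrainedTokenizerBase, T5EncoderModel, T5TokenizerFast

# Concrete classes satisfy the base-class annotations used above.
print(issubclass(T5EncoderModel, PreTrainedModel))            # True
print(issubclass(T5TokenizerFast, PreTrainedTokenizerBase))   # True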

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import Gemma2PreTrainedModel, GemmaTokenizerFast, GemmaTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
@@ -160,8 +160,8 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
-        text_encoder: PreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: FlowMatchEulerDiscreteScheduler,

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 13 additions & 5 deletions
@@ -1047,12 +1047,20 @@ def get_detailed_type(obj: Any) -> Type:
         else:
             return obj_type
 
-    for key, class_obj in init_kwargs.items():
-        if "scheduler" in key:
+    for kw, arg in init_kwargs.items():
+        # Too complex to validate with type annotation alone
+        if "scheduler" in kw:
             continue
-
-        if class_obj is not None and not is_valid_type(class_obj, expected_types[key]):
-            logger.warning(f"Expected types for {key}: {expected_types[key]}, got {get_detailed_type(class_obj)}.")
+        # Many tokenizer annotations don't include their "Fast" variant, so skip these,
+        # e.g. T5Tokenizer but not T5TokenizerFast
+        elif "tokenizer" in kw:
+            continue
+        elif (
+            arg is not None
+            and expected_types[kw] is not inspect.Signature.empty  # no type annotations
+            and not is_valid_type(arg, expected_types[kw])
+        ):
+            logger.warning(f"Expected types for {kw}: {expected_types[kw]}, got {get_detailed_type(arg)}.")
 
     # 11. Instantiate the pipeline
     model = pipeline_class(**init_kwargs)
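
Note: a minimal sketch, not part of the commit, of why tokenizer keys are skipped. A "Fast" tokenizer is not a subclass of its slow counterpart, so an annotation naming only the slow class would spuriously reject the fast variant.

from transformers import T5Tokenizer, T5TokenizerFast

# The two variants sit in separate hierarchies under PreTrainedTokenizerBase,
# so an isinstance-style check against the slow class fails for the fast one.
print(issubclass(T5TokenizerFast, T5Tokenizer))  # False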

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 3 additions & 3 deletions
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
@@ -200,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
     def __init__(
         self,
-        tokenizer: PreTrainedTokenizerBase,
-        text_encoder: PreTrainedModel,
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: DPMSolverMultistepScheduler,
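
Note: a sketch of one plausible way a validator can accept the new Union annotation, so both Gemma tokenizer variants pass. check_union is a hypothetical helper for illustration, not the library's is_valid_type.

from typing import Union, get_args, get_origin

def check_union(obj, annotation) -> bool:
    # isinstance accepts a tuple of classes, so Union members can be
    # checked in one call after unpacking them with get_args.
    if get_origin(annotation) is Union:
        return isinstance(obj, get_args(annotation))
    return isinstance(obj, annotation)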

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 6 additions & 6 deletions
@@ -17,10 +17,10 @@
 
 import torch
 from transformers import (
-    BaseImageProcessor,
+    SiglipImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -176,9 +176,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """
 
@@ -197,8 +197,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: SiglipVisionModel = None,
+        feature_extractor: SiglipImageProcessor = None,
     ):
         super().__init__()
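
Note: a hedged usage sketch for the tightened signature. The checkpoint names are illustrative, not prescribed by this commit.

import torch
from diffusers import StableDiffusion3Pipeline
from transformers import SiglipImageProcessor, SiglipVisionModel

# Load the SigLIP components the annotations now expect, then hand them
# to the pipeline as the IP-Adapter image encoder and feature extractor.
image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    image_encoder=image_encoder,
    feature_extractor=feature_extractor,
    torch_dtype=torch.float16,
)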

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 8 additions & 4 deletions
@@ -18,10 +18,10 @@
 import PIL.Image
 import torch
 from transformers import (
-    BaseImageProcessor,
+    SiglipImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -197,6 +197,10 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+        image_encoder (`SiglipVisionModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`SiglipImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
@@ -214,8 +218,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

Lines changed: 6 additions & 6 deletions
@@ -17,10 +17,10 @@
 
 import torch
 from transformers import (
-    BaseImageProcessor,
+    SiglipImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
-    PreTrainedModel,
+    SiglipVisionModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
@@ -196,9 +196,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         tokenizer_3 (`T5TokenizerFast`):
             Tokenizer of class
             [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
-        image_encoder (`PreTrainedModel`, *optional*):
+        image_encoder (`SiglipVisionModel`, *optional*):
             Pre-trained Vision Model for IP Adapter.
-        feature_extractor (`BaseImageProcessor`, *optional*):
+        feature_extractor (`SiglipImageProcessor`, *optional*):
             Image processor for IP Adapter.
     """
 
@@ -217,8 +217,8 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         text_encoder_3: T5EncoderModel,
         tokenizer_3: T5TokenizerFast,
-        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None,
+        image_encoder: Optional[SiglipVisionModel] = None,
+        feature_extractor: Optional[SiglipImageProcessor] = None,
     ):
         super().__init__()

src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 27 additions & 11 deletions
@@ -19,15 +19,31 @@
 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
 from k_diffusion.sampling import BrownianTreeNoiseSampler, get_sigmas_karras
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPTokenizerFast,
+)
 
 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import (
+    StableDiffusionLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
-from ...schedulers import LMSDiscreteScheduler
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from ..stable_diffusion import StableDiffusionPipelineOutput
+from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -95,13 +111,13 @@ class StableDiffusionKDiffusionPipeline(
 
     def __init__(
        self,
-        vae,
-        text_encoder,
-        tokenizer,
-        unet,
-        scheduler,
-        safety_checker,
-        feature_extractor,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: Union[CLIPTokenizer, CLIPTokenizerFast],
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
     ):
         super().__init__()
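
Note: a hedged usage sketch for the now-typed pipeline; the checkpoint name is illustrative, not prescribed by this commit.

from diffusers import StableDiffusionKDiffusionPipeline

pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5"
)
# Pick a k-diffusion sampler by name via this pipeline's set_scheduler helper.
pipe.set_scheduler("sample_dpmpp_2m")
image = pipe("an astronaut riding a horse on mars").images[0]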
