
Commit 68169f8

Updated docstrings and doc entries
1 parent 27fe083 commit 68169f8


7 files changed: +196 -38 lines changed


docs/source/en/api/attnprocessor.md

Lines changed: 3 additions & 0 deletions
@@ -52,3 +52,6 @@ An attention processor is a class for applying different types of attention mechanisms.
 
 ## AttnProcessorNPU
 [[autodoc]] models.attention_processor.AttnProcessorNPU
+
+## IPAdapterJointAttnProcessor2_0
+[[autodoc]] models.attention_processor.IPAdapterJointAttnProcessor2_0

docs/source/en/api/loaders/ip_adapter.md

Lines changed: 6 additions & 0 deletions
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
 
 [[autodoc]] loaders.ip_adapter.IPAdapterMixin
 
+## SD3IPAdapterMixin
+
+[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+  - all
+  - is_ip_adapter_active
+
 ## IPAdapterMaskProcessor
 
 [[autodoc]] image_processor.IPAdapterMaskProcessor
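For orientation, here is a minimal usage sketch of the `SD3IPAdapterMixin` API documented above. The repository id and weight file name are placeholders, not values taken from this commit.

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Placeholder repo id and weight name; substitute a real SD3 IP-Adapter checkpoint.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
)
pipe.load_ip_adapter(
    "<ip-adapter-repo>",                    # placeholder repository id
    weight_name="ip-adapter.safetensors",   # placeholder file name
    image_encoder_folder="image_encoder",
)
pipe.set_ip_adapter_scale(0.6)
print(pipe.is_ip_adapter_active)  # True once any IP-Adapter layer has scale > 0
```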

src/diffusers/loaders/ip_adapter.py

Lines changed: 23 additions & 18 deletions
@@ -358,13 +358,13 @@ class SD3IPAdapterMixin:
 
     @property
     def is_ip_adapter_active(self) -> bool:
-        r"""Checks if any ip_adapter attention processor have scale > 0.
+        """Checks if IP-Adapter is loaded and scale > 0.
 
         IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
-        image is irrelevant.
+        the image context is irrelevant.
 
         Returns:
-            `bool`: True when ip_adapter is loaded and any ip_adapter layer scale > 0.
+            `bool`: True when IP-Adapter is loaded and any layer has scale > 0.
         """
         scales = [
             attn_proc.scale
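A minimal sketch of the semantics described in the updated docstring (not the exact implementation): the property is `True` only when IP-Adapter attention processors are present and at least one of them has a positive scale.

```python
def is_ip_adapter_active(transformer) -> bool:
    # Assumption for this sketch: only IP-Adapter processors expose a `scale` attribute.
    scales = [
        proc.scale
        for proc in transformer.attn_processors.values()
        if hasattr(proc, "scale")
    ]
    # False when nothing is loaded, or when every layer has scale == 0 (image prompt ignored).
    return len(scales) > 0 and any(s > 0 for s in scales)
```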
@@ -382,7 +382,7 @@ def load_ip_adapter(
         weight_name: str,
         image_encoder_folder: Optional[str] = "image_encoder",
         **kwargs,
-    ):
+    ) -> None:
         """
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -500,19 +500,19 @@ def load_ip_adapter(
                     image_encoder_subfolder = Path(image_encoder_folder).as_posix()
 
                 # Commons args for loading image encoder and image processor
-                args = dict(
-                    pretrained_model_name_or_path_or_dict,
-                    subfolder=image_encoder_subfolder,
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                    cache_dir=cache_dir,
-                    local_files_only=local_files_only,
-                )
+                kwargs = {
+                    "low_cpu_mem_usage": low_cpu_mem_usage,
+                    "cache_dir": cache_dir,
+                    "local_files_only": local_files_only,
+                }
 
                 self.register_modules(
-                    feature_extractor=SiglipImageProcessor.from_pretrained(**args).to(
+                    feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
+                        self.device, dtype=self.dtype
+                    ),
+                    image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
                         self.device, dtype=self.dtype
                     ),
-                    image_encoder=SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype),
                 )
             else:
                 raise ValueError(
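The refactor above shares one dict of loading options between the Siglip image processor and vision model. A hedged stand-alone sketch of the same idea, with a placeholder repository id and subfolder:

```python
from transformers import SiglipImageProcessor, SiglipVisionModel

repo_id = "<sd3-ip-adapter-repo>"  # placeholder
common = {"cache_dir": None, "local_files_only": False}  # shared loading options

feature_extractor = SiglipImageProcessor.from_pretrained(repo_id, subfolder="image_encoder", **common)
image_encoder = SiglipVisionModel.from_pretrained(repo_id, subfolder="image_encoder", **common)
```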
@@ -527,11 +527,11 @@ def load_ip_adapter(
         # Load IP-Adapter into transformer
         self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
 
-    def set_ip_adapter_scale(self, scale: float):
+    def set_ip_adapter_scale(self, scale: float) -> None:
         """
-        Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image
-        prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages the model to produce more
-        diverse images, but they may not be as aligned with the image prompt.
+        Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
+        conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
+        the model to produce more diverse images, but they may not be as aligned with the image prompt.
 
         Example:
@@ -540,12 +540,17 @@ def set_ip_adapter_scale(self, scale: float):
         >>> pipeline.set_ip_adapter_scale(0.6)
         >>> ...
         ```
+
+        Args:
+            scale (float):
+                IP-Adapter scale to be set.
+
         """
         for attn_processor in self.transformer.attn_processors.values():
             if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0):
                 attn_processor.scale = scale
 
-    def unload_ip_adapter(self):
+    def unload_ip_adapter(self) -> None:
         """
         Unloads the IP Adapter weights.

src/diffusers/models/attention_processor.py

Lines changed: 38 additions & 1 deletion
@@ -5149,7 +5149,22 @@ def __call__(
 
 
 class IPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections."""
+    """
+    Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
+    additional image-based information and timestep embeddings.
+
+    Args:
+        hidden_size (`int`):
+            The number of hidden channels.
+        ip_hidden_states_dim (`int`):
+            The image feature dimension.
+        head_dim (`int`):
+            The number of head channels.
+        timesteps_emb_dim (`int`, defaults to 1280):
+            The number of input channels for timestep embedding.
+        scale (`float`, defaults to 0.5):
+            IP-Adapter scale.
+    """
 
     def __init__(
         self,
@@ -5181,6 +5196,28 @@ def __call__(
         ip_hidden_states: torch.FloatTensor = None,
         temb: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
+        """
+        Perform the attention computation, integrating image features (if provided) and timestep embeddings.
+
+        If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0.
+
+        Args:
+            attn (`Attention`):
+                Attention instance.
+            hidden_states (`torch.FloatTensor`):
+                Input `hidden_states`.
+            encoder_hidden_states (`torch.FloatTensor`, *optional*):
+                The encoder hidden states.
+            attention_mask (`torch.FloatTensor`, *optional*):
+                Attention mask.
+            ip_hidden_states (`torch.FloatTensor`, *optional*):
+                Image embeddings.
+            temb (`torch.FloatTensor`, *optional*):
+                Timestep embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Output hidden states.
+        """
         residual = hidden_states
 
         batch_size = hidden_states.shape[0]
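A hedged sketch of constructing the processor with the arguments documented in the new class docstring; the head count and dimensions are illustrative, not values prescribed by this commit.

```python
from diffusers.models.attention_processor import IPAdapterJointAttnProcessor2_0

num_heads, head_dim = 24, 64  # illustrative SD3-like sizes
processor = IPAdapterJointAttnProcessor2_0(
    hidden_size=num_heads * head_dim,           # number of hidden channels
    ip_hidden_states_dim=num_heads * head_dim,  # image feature dimension
    head_dim=head_dim,
    timesteps_emb_dim=1280,                     # documented default
    scale=0.5,                                  # documented default
)
# The processor would then be assigned to the transformer's attention layers,
# e.g. via `transformer.set_attn_processor(...)` (wiring omitted in this sketch).
```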

src/diffusers/models/embeddings.py

Lines changed: 66 additions & 2 deletions
@@ -2119,6 +2119,19 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
 
 
 class IPAdapterTimeImageProjectionBlock(nn.Module):
+    """Block for IPAdapterTimeImageProjection.
+
+    Args:
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+    """
+
     def __init__(
         self,
         hidden_dim: int = 1280,
@@ -2152,7 +2165,21 @@ def __init__(
         self.attn.to_k = None
         self.attn.to_v = None
 
-    def forward(self, x, latents, timestep_emb):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            latents (`torch.Tensor`):
+                Latent features.
+            timestep_emb (`torch.Tensor`):
+                Timestep embedding.
+
+        Returns:
+            `torch.Tensor`: Output latent features.
+        """
+
         # Shift and scale for AdaLayerNorm
         emb = self.adaln_proj(self.adaln_silu(timestep_emb))
         shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
@@ -2192,6 +2219,33 @@ def forward(self, x, latents, timestep_emb):
 
 # Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
 class IPAdapterTimeImageProjection(nn.Module):
+    """Resampler of SD3 IP-Adapter with timestep embedding.
+
+    Args:
+        embed_dim (`int`, defaults to 1152):
+            The feature dimension.
+        output_dim (`int`, defaults to 2432):
+            The number of output channels.
+        hidden_dim (`int`, defaults to 1280):
+            The number of hidden channels.
+        depth (`int`, defaults to 4):
+            The number of blocks.
+        dim_head (`int`, defaults to 64):
+            The number of head channels.
+        heads (`int`, defaults to 20):
+            Parallel attention heads.
+        num_queries (`int`, defaults to 64):
+            The number of queries.
+        ffn_ratio (`int`, defaults to 4):
+            The expansion ratio of feedforward network hidden layer channels.
+        timestep_in_dim (`int`, defaults to 320):
+            The number of input channels for timestep embedding.
+        timestep_flip_sin_to_cos (`bool`, defaults to True):
+            Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
+        timestep_freq_shift (`int`, defaults to 0):
+            Controls the timestep delta between frequencies between dimensions.
+    """
+
     def __init__(
         self,
         embed_dim: int = 1152,
@@ -2217,7 +2271,17 @@ def __init__(
         self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
         self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
 
-    def forward(self, x, timestep):
+    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x (`torch.Tensor`):
+                Image features.
+            timestep (`torch.Tensor`):
+                Timestep in denoising process.
+        Returns:
+            `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
+        """
         timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
         timestep_emb = self.time_embedding(timestep_emb)
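A hedged shape sketch of `IPAdapterTimeImageProjection.forward` using the documented defaults; the token count of the image features is an assumption and depends on the image encoder.

```python
import torch
from diffusers.models.embeddings import IPAdapterTimeImageProjection

proj = IPAdapterTimeImageProjection()        # documented defaults
image_features = torch.randn(2, 729, 1152)   # (batch, tokens, embed_dim); token count assumed
timestep = torch.tensor([500, 500])          # one timestep per batch element

latents, timestep_emb = proj(image_features, timestep)
print(latents.shape)       # expected (2, 64, 2432): (batch, num_queries, output_dim)
print(timestep_emb.shape)  # expected (2, 1280): (batch, hidden_dim)
```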

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 13 additions & 1 deletion
@@ -331,7 +331,19 @@ def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):
+    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool) -> None:
+        """Sets IP-Adapter attention processors, image projection, and loads state_dict.
+
+        Args:
+            state_dict (`Dict`):
+                PyTorch state dict with keys "ip_adapter", which contains parameters for attention processors, and
+                "image_proj", which contains parameters for image projection net.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+        """
         # IP-Adapter cross attention parameters
         hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
         ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
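A hedged sketch of the state dict layout the new docstring describes: two top-level keys, "image_proj" for the projection network and "ip_adapter" for the attention processors. The file name is a placeholder, `transformer` stands for an `SD3Transformer2DModel` instance, and the nested parameter names are not shown in this commit.

```python
import torch

raw = torch.load("ip_adapter_sd3.bin", map_location="cpu")  # placeholder file name
state_dict = {
    "image_proj": raw["image_proj"],  # parameters for the image projection net
    "ip_adapter": raw["ip_adapter"],  # parameters for the IP-Adapter attention processors
}
transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)
```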

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 47 additions & 16 deletions
@@ -680,7 +680,16 @@ def interrupt(self):
         return self._interrupt
 
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
-    def encode_image(self, image):
+    def encode_image(self, image: PipelineImageInput) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values
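A brief hedged usage sketch of the newly documented `encode_image`; `pipe` is assumed to be an SD3 pipeline whose `feature_extractor` and `image_encoder` were registered by `load_ip_adapter`, and the file name is a placeholder.

```python
from PIL import Image

reference = Image.open("reference.png").convert("RGB")  # placeholder path
image_embeds = pipe.encode_image(reference)             # a `torch.Tensor` of image features
```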

@@ -690,17 +699,42 @@
 
     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
-    ):
-        if ip_adapter_image_embeds is None:
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device: (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
             single_image_embeds = self.encode_image(ip_adapter_image)
             if do_classifier_free_guidance:
                 single_negative_image_embeds = torch.zeros_like(single_image_embeds)
         else:
-            if do_classifier_free_guidance:
-                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-            else:
-                single_image_embeds = ip_adapter_image_embeds
+            raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
 
         image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
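A hedged sketch of the classifier-free-guidance layout the reworked method expects when precomputed embeddings are passed: the unconditional (negative) half comes first, so `chunk(2)` recovers the (negative, positive) pair. Shapes are illustrative.

```python
import torch

single_image_embeds = torch.randn(1, 64, 2432)  # "positive" image embedding (illustrative shape)
single_negative_image_embeds = torch.zeros_like(single_image_embeds)

# Layout expected for `ip_adapter_image_embeds` under classifier-free guidance:
ip_adapter_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])

negative, positive = ip_adapter_image_embeds.chunk(2)  # what the method does internally
assert torch.equal(positive, single_image_embeds)
```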

@@ -733,7 +767,7 @@ def __call__(
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
-        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -810,11 +844,10 @@ def __call__(
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
             ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
-                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -950,8 +983,6 @@ def __call__(
         )
 
         # 6. Prepare image embeddings
-        # Either image is passed and ip_adapter is active
-        # Or image_embeds are passed directly
         if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
             ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
                 ip_adapter_image,
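Putting the changes together, a hedged end-to-end sketch of calling the pipeline with `ip_adapter_image`; repository ids, weight names, and the image URL are placeholders.

```python
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("<ip-adapter-repo>", weight_name="ip-adapter.safetensors")  # placeholders
pipe.set_ip_adapter_scale(0.6)

reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(
    prompt="a photo in the style of the reference image",
    ip_adapter_image=reference,  # embeddings are computed via prepare_ip_adapter_image_embeds
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```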
