Commit 8323240: Quality and style checks
1 parent 461ab73

File tree: 6 files changed, +59 -61 lines


src/diffusers/loaders/ip_adapter.py

Lines changed: 13 additions & 19 deletions
@@ -33,23 +33,19 @@


if is_transformers_available():
-    from transformers import (
-        CLIPImageProcessor,
-        CLIPVisionModelWithProjection,
-        SiglipImageProcessor,
-        SiglipVisionModel
-    )
+    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel

    from ..models.attention_processor import (
        AttnProcessor,
        AttnProcessor2_0,
-        JointAttnProcessor2_0,
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
-        IPAdapterXFormersAttnProcessor,
        IPAdapterJointAttnProcessor2_0,
+        IPAdapterXFormersAttnProcessor,
+        JointAttnProcessor2_0,
    )

+
logger = logging.get_logger(__name__)


@@ -495,8 +491,10 @@ def load_ip_adapter(
            )

            self.register_modules(
-                feature_extractor = SiglipImageProcessor.from_pretrained(**args).to(self.device, dtype=self.dtype),
-                image_encoder = SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype),
+                feature_extractor=SiglipImageProcessor.from_pretrained(**args).to(
+                    self.device, dtype=self.dtype
+                ),
+                image_encoder=SiglipVisionModel.from_pretrained(**args).to(self.device, dtype=self.dtype),
            )
        else:
            raise ValueError(
@@ -513,9 +511,9 @@ def load_ip_adapter(

    def set_ip_adapter_scale(self, scale: float):
        """
-        Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image prompt, and 0.0
-        only conditioned by the text prompt. Lowering this value encourages the model to produce more diverse images, but they
-        may not be as aligned with the image prompt.
+        Controls image/text prompt conditioning. A value of 1.0 means the model is only conditioned on the image
+        prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages the model to produce more
+        diverse images, but they may not be as aligned with the image prompt.

        Example:

@@ -556,11 +554,7 @@ def unload_ip_adapter(self):

        # Restore original attention processors layers
        attn_procs = {
-            name: (
-                JointAttnProcessor2_0()
-                if isinstance(value, IPAdapterJointAttnProcessor2_0)
-                else value.__class__()
-            )
+            name: (JointAttnProcessor2_0() if isinstance(value, IPAdapterJointAttnProcessor2_0) else value.__class__())
            for name, value in self.transformer.attn_processors.items()
        }
-        self.transformer.set_attn_processor(attn_procs)
+        self.transformer.set_attn_processor(attn_procs)
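
For context, set_ip_adapter_scale is the knob the reworded docstring describes. A minimal usage sketch, assuming an SD3 pipeline with IP-Adapter weights loaded through load_ip_adapter; the checkpoint id and image path below are placeholders, not taken from this commit:

import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

# Hypothetical checkpoint ids/paths, shown only to illustrate the API surface
# touched by this file; substitute real ones.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
pipe.load_ip_adapter("some-org/sd3-ip-adapter")  # placeholder repo id

# 1.0 = condition only on the image prompt, 0.0 = only on the text prompt.
pipe.set_ip_adapter_scale(0.6)

image = pipe(
    prompt="a cat wearing a spacesuit",
    ip_adapter_image=load_image("reference.png"),  # placeholder path
).images[0]

# Drop the adapter and restore the original attention processors.
pipe.unload_ip_adapter()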

src/diffusers/models/attention.py

Lines changed: 8 additions & 4 deletions
@@ -188,8 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        self._chunk_dim = dim

    def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
-        joint_attention_kwargs: Dict[str, Any] = {}
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Dict[str, Any] = {},
    ):
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -207,8 +210,9 @@ def forward(

        # Attention.
        attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
-            **joint_attention_kwargs
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
        )

        # Process attention outputs for the `hidden_states`.

src/diffusers/models/attention_processor.py

Lines changed: 8 additions & 7 deletions
@@ -5047,7 +5047,7 @@ def __call__(
        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
-
+

class IPAdapterJointAttnProcessor2_0(torch.nn.Module):
    """Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections."""
@@ -5058,15 +5058,14 @@ def __init__(
        ip_hidden_states_dim: int,
        head_dim: int,
        timesteps_emb_dim: int = 1280,
-        scale: float = 0.5
+        scale: float = 0.5,
    ):
        super().__init__()

        # To prevent circular import
-        from .normalization import RMSNorm, AdaLayerNorm
+        from .normalization import AdaLayerNorm, RMSNorm

-        self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2,
-                                    norm_eps=1e-6, chunk_dim=1)
+        self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1)
        self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
        self.norm_q = RMSNorm(head_dim, 1e-6)
@@ -5081,7 +5080,7 @@ def __call__(
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        ip_hidden_states: torch.FloatTensor = None,
-        temb: torch.FloatTensor = None
+        temb: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        residual = hidden_states

@@ -5170,7 +5169,9 @@ def __call__(
        img_key = torch.cat([img_key, ip_key], dim=2)
        img_value = torch.cat([img_value, ip_value], dim=2)

-        ip_hidden_states = F.scaled_dot_product_attention(img_query, img_key, img_value, dropout_p=0.0, is_causal=False)
+        ip_hidden_states = F.scaled_dot_product_attention(
+            img_query, img_key, img_value, dropout_p=0.0, is_causal=False
+        )
        ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(img_query.dtype)
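
The last hunk above is the core of the IP-Adapter joint attention: extra image-prompt key/value tokens are concatenated onto the sequence before scaled dot-product attention. A self-contained sketch of that pattern with toy shapes (tensor names and sizes are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

batch, heads, seq_len, ip_len, head_dim = 2, 8, 16, 4, 64

# Queries come from the image tokens; the IP-Adapter contributes extra
# key/value tokens projected from the image-prompt embeddings.
img_query = torch.randn(batch, heads, seq_len, head_dim)
img_key = torch.randn(batch, heads, seq_len, head_dim)
img_value = torch.randn(batch, heads, seq_len, head_dim)
ip_key = torch.randn(batch, heads, ip_len, head_dim)
ip_value = torch.randn(batch, heads, ip_len, head_dim)

# Concatenate along the sequence dimension (dim=2), as in the hunk above.
img_key = torch.cat([img_key, ip_key], dim=2)
img_value = torch.cat([img_value, ip_value], dim=2)

ip_hidden_states = F.scaled_dot_product_attention(
    img_query, img_key, img_value, dropout_p=0.0, is_causal=False
)
# Back to (batch, seq_len, heads * head_dim) before the output projection.
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch, -1, heads * head_dim)
print(ip_hidden_states.shape)  # torch.Size([2, 16, 512])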

src/diffusers/models/embeddings.py

Lines changed: 9 additions & 11 deletions
@@ -2115,7 +2115,7 @@ def __init__(
    ) -> None:
        super().__init__()

-        self.scale = dim_head ** -0.5
+        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads
@@ -2135,6 +2135,7 @@ def forward(self, x, latents, shift=None, scale=None):
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
+
        def reshape_tensor(x, heads):
            bs, length, _ = x.shape
            # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
@@ -2169,7 +2170,7 @@ def reshape_tensor(x, heads):
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
-
+

# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
class TimePerceiverResampler(nn.Module):
@@ -2188,12 +2189,12 @@ def __init__(
        timestep_freq_shift: int = 0,
    ) -> None:
        super().__init__()
-
-        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim ** 0.5)
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
        self.proj_in = nn.Linear(embed_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
-
+
        ff_inner_dim = int(hidden_dim * ffn_ratio)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
@@ -2210,10 +2211,7 @@ def __init__(
                            nn.Linear(ff_inner_dim, hidden_dim, bias=False),
                        ),
                        # adaLN
-                        nn.Sequential(
-                            nn.SiLU(),
-                            nn.Linear(hidden_dim, ff_inner_dim, bias=True)
-                        )
+                        nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, ff_inner_dim, bias=True)),
                    ]
                )
            )
@@ -2227,7 +2225,7 @@ def forward(self, x, timestep, need_temb=False):
        timestep_emb = self.time_embedding(timestep_emb, None)

        latents = self.latents.repeat(x.size(0), 1, 1)
-
+
        x = self.proj_in(x)
        x = x + timestep_emb[:, None]

@@ -2242,7 +2240,7 @@ def forward(self, x, timestep, need_temb=False):
            if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
            latents = latents + res
-
+
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
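
The TimePerceiverResampler feed-forward path above applies adaLN-style modulation: a timestep embedding is projected to per-channel shift/scale values that rescale the normalized latents. A toy sketch of just that modulation step, with made-up dimensions; the projection width and variable names here are illustrative, not the module's real attributes:

import torch
import torch.nn as nn

hidden_dim, num_latents, batch = 64, 8, 2

# adaLN-style projection: timestep embedding -> (shift, scale). The module in
# the diff maps hidden_dim -> ff_inner_dim; 2 * hidden_dim is used here only
# so the result splits cleanly into shift and scale.
adaln = nn.Sequential(nn.SiLU(), nn.Linear(hidden_dim, 2 * hidden_dim, bias=True))

timestep_emb = torch.randn(batch, hidden_dim)
shift_mlp, scale_mlp = adaln(timestep_emb).chunk(2, dim=-1)

latents = torch.randn(batch, num_latents, hidden_dim)
norm = nn.LayerNorm(hidden_dim)

# Mirrors: latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
latents = norm(latents) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
print(latents.shape)  # torch.Size([2, 8, 64])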

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 5 additions & 3 deletions
@@ -24,8 +24,8 @@
    Attention,
    AttentionProcessor,
    FusedJointAttnProcessor2_0,
-    JointAttnProcessor2_0,
    IPAdapterJointAttnProcessor2_0,
+    JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin, load_model_dict_into_meta
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
@@ -376,7 +376,7 @@ def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool):
            hidden_dim=hidden_dim,
            heads=heads,
            num_queries=num_queries,
-            timestep_in_dim=timestep_in_dim
+            timestep_in_dim=timestep_in_dim,
        ).to(device=self.device, dtype=self.dtype)

        if not low_cpu_mem_usage:
@@ -470,7 +470,9 @@ def custom_forward(*inputs):
                )
            elif not is_skip:
                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 16 additions & 17 deletions
@@ -17,16 +17,16 @@

import torch
from transformers import (
+    BaseImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
+    PreTrainedModel,
    T5EncoderModel,
    T5TokenizerFast,
-    PreTrainedModel,
-    BaseImageProcessor,
)

-from ...image_processor import VaeImageProcessor, PipelineImageInput
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin, SD3IPAdapterMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -184,7 +184,7 @@ def __init__(
        text_encoder_3: T5EncoderModel,
        tokenizer_3: T5TokenizerFast,
        image_encoder: PreTrainedModel = None,
-        feature_extractor: BaseImageProcessor = None
+        feature_extractor: BaseImageProcessor = None,
    ):
        super().__init__()

@@ -199,7 +199,7 @@ def __init__(
            transformer=transformer,
            scheduler=scheduler,
            image_encoder=image_encoder,
-            feature_extractor=feature_extractor
+            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -678,7 +678,7 @@ def num_timesteps(self):
    @property
    def interrupt(self):
        return self._interrupt
-
+
    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
    def encode_image(self, image):
        if not isinstance(image, torch.Tensor):
@@ -687,16 +687,18 @@ def encode_image(self, image):
        image = image.to(device=self.device, dtype=self.dtype)

        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-        uncond_image_enc_hidden_states = self.image_encoder(torch.zeros_like(image), output_hidden_states=True).hidden_states[-2]
-
+        uncond_image_enc_hidden_states = self.image_encoder(
+            torch.zeros_like(image), output_hidden_states=True
+        ).hidden_states[-2]
+
        return image_enc_hidden_states, uncond_image_enc_hidden_states

    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        if ip_adapter_image_embeds is None:
-            single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
@@ -705,13 +707,13 @@ def prepare_ip_adapter_image_embeds(
                single_image_embeds = ip_adapter_image_embeds

        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-
+
        if do_classifier_free_guidance:
            single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

        return single_image_embeds.to(device=device)
-
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
@@ -979,15 +981,12 @@ def __call__(
                    need_temb=True,
                )

-                image_prompt_embeds = dict(
-                    ip_hidden_states=ip_hidden_states,
-                    temb=temb
-                )
+                image_prompt_embeds = {"ip_hidden_states": ip_hidden_states, "temb": temb}

                if self.joint_attention_kwargs is None:
                    self._joint_attention_kwargs = image_prompt_embeds
                else:
-                    self._joint_attention_kwargs.update(**image_prompt_embeds)
+                    self._joint_attention_kwargs.update(**image_prompt_embeds)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
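
prepare_ip_adapter_image_embeds above doubles the image embeddings for classifier-free guidance by stacking the unconditional (zero-image) and conditional embeddings along the batch dimension. A small sketch of that pattern in isolation, with illustrative tensor shapes:

import torch

num_images_per_prompt, seq_len, dim = 2, 4, 8

# Conditional embeddings from the reference image and unconditional ones from
# a zeroed image, as encode_image() returns in the diff above.
single_image_embeds = torch.randn(1, seq_len, dim)
single_negative_image_embeds = torch.randn(1, seq_len, dim)

# Repeat per generated image, then stack [negative, positive] so one forward
# pass of the transformer covers both guidance branches.
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
cfg_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

print(cfg_image_embeds.shape)  # torch.Size([4, 4, 8])

# Downstream, the noise prediction is split back into the two branches, e.g.:
# noise_uncond, noise_cond = noise_pred.chunk(2)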
