Commit 89c4e63

Quality and style checks
1 parent 8323240 commit 89c4e63

1 file changed: +53 −73 lines changed

examples/community/pipeline_stable_diffusion_3_ipa.py

Lines changed: 53 additions & 73 deletions
@@ -13,11 +13,13 @@
 # limitations under the License.
 
 import inspect
+import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers import (
     CLIPTextModelWithProjection,
     CLIPTokenizer,
@@ -28,6 +30,11 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from diffusers.models.autoencoders import AutoencoderKL
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.normalization import RMSNorm
+from diffusers.models.transformers import SD3Transformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import (
     USE_PEFT_BACKEND,
@@ -38,15 +45,6 @@
     unscale_lora_layers,
 )
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
-
-from diffusers.models.transformers import SD3Transformer2DModel
-from diffusers.models.normalization import RMSNorm
-from einops import rearrange
-import math
-
-from diffusers.models.embeddings import Timesteps, TimestepEmbedding
 
 
 if is_torch_xla_available():
@@ -86,10 +84,10 @@ def FeedForward(dim, mult=4):
         nn.Linear(inner_dim, dim, bias=False),
     )
 
-
+
 def reshape_tensor(x, heads):
     bs, length, width = x.shape
-    #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
     x = x.transpose(1, 2)
@@ -113,7 +111,6 @@ def __init__(self, *, dim, dim_head=64, heads=8):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
-
     def forward(self, x, latents, shift=None, scale=None):
         """
         Args:
@@ -127,23 +124,23 @@ def forward(self, x, latents, shift=None, scale=None):
 
         if shift is not None and scale is not None:
             latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
+
         b, l, _ = latents.shape
 
         q = self.to_q(latents)
         kv_input = torch.cat((x, latents), dim=-2)
         k, v = self.to_kv(kv_input).chunk(2, dim=-1)
-
+
         q = reshape_tensor(q, self.heads)
         k = reshape_tensor(k, self.heads)
         v = reshape_tensor(v, self.heads)
 
         # attention
         scale = 1 / math.sqrt(math.sqrt(self.dim_head))
-        weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
         weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
         out = weight @ v
-
+
         out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
 
         return self.to_out(out)
@@ -166,14 +163,14 @@ def __init__(
         timestep_freq_shift=0,
     ):
         super().__init__()
-
+
         self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
-
+
         self.proj_in = nn.Linear(embedding_dim, dim)
 
         self.proj_out = nn.Linear(dim, output_dim)
         self.norm_out = nn.LayerNorm(output_dim)
-
+
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(
@@ -184,7 +181,7 @@ def __init__(
                         # ff
                         FeedForward(dim=dim, mult=ff_mult),
                         # adaLN
-                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
+                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True)),
                     ]
                 )
             )
@@ -199,12 +196,11 @@ def __init__(
         # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
         # )
 
-
     def forward(self, x, timestep, need_temb=False):
         timestep_emb = self.embedding_time(x, timestep) # bs, dim
 
         latents = self.latents.repeat(x.size(0), 1, 1)
-
+
         x = self.proj_in(x)
         x = x + timestep_emb[:, None]
 
@@ -221,7 +217,7 @@ def forward(self, x, timestep, need_temb=False):
             latents = latents + res
 
             # latents = ff(latents) + latents
-
+
         latents = self.proj_out(latents)
         latents = self.norm_out(latents)
 
@@ -230,10 +226,7 @@ def forward(self, x, timestep, need_temb=False):
         else:
             return latents
 
-
-
     def embedding_time(self, sample, timestep):
-
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
@@ -271,32 +264,29 @@ class AdaLayerNorm(nn.Module):
         num_embeddings (`int`): The size of the embeddings dictionary.
     """
 
-    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
+    def __init__(self, embedding_dim: int, time_embedding_dim=None, mode="normal"):
         super().__init__()
 
         self.silu = nn.SiLU()
-        num_params_dict = dict(
-            zero=6,
-            normal=2,
-        )
-        num_params = num_params_dict[mode]
+
+        num_params = 2 if mode == "normal" else 6
         self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
         self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
         self.mode = mode
 
     def forward(
         self,
         x,
-        hidden_dtype = None,
-        emb = None,
+        hidden_dtype=None,
+        emb=None,
     ):
         emb = self.linear(self.silu(emb))
-        if self.mode == 'normal':
+        if self.mode == "normal":
             shift_msa, scale_msa = emb.chunk(2, dim=1)
             x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
             return x
 
-        elif self.mode == 'zero':
+        elif self.mode == "zero":
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
             x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
             return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@@ -323,7 +313,6 @@ def __init__(
         self.norm_k = RMSNorm(head_dim, 1e-6)
         self.norm_ip_k = RMSNorm(head_dim, 1e-6)
 
-
     def __call__(
         self,
         attn,
@@ -396,9 +385,8 @@ def __call__(
         if not attn.context_pre_only:
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
-
         # IPadapter
-        ip_hidden_states = emb_dict.get('ip_hidden_states', None)
+        ip_hidden_states = emb_dict.get("ip_hidden_states", None)
         ip_hidden_states = self.get_ip_hidden_states(
             attn,
             img_query,
@@ -407,11 +395,10 @@ def __call__(
             img_value,
             None,
             None,
-            emb_dict['temb'],
+            emb_dict["temb"],
         )
         if ip_hidden_states is not None:
-            hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
-
+            hidden_states = hidden_states + ip_hidden_states * emb_dict.get("scale", 1.0)
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -423,12 +410,13 @@ def __call__(
         else:
             return hidden_states
 
-
-    def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
+    def get_ip_hidden_states(
+        self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None
+    ):
         if ip_hidden_states is None:
             return None
-
-        if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
+
+        if not hasattr(self, "to_k_ip") or not hasattr(self, "to_v_ip"):
             return None
 
         # norm ip input
@@ -439,11 +427,11 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_
         ip_value = self.to_v_ip(norm_ip_hidden_states)
 
         # reshape
-        query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
-        img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
-        img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
-        ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
+        query = rearrange(query, "b l (h d) -> b h l d", h=attn.heads)
+        img_key = rearrange(img_key, "b l (h d) -> b h l d", h=attn.heads)
+        img_value = rearrange(img_value, "b l (h d) -> b h l d", h=attn.heads)
+        ip_key = rearrange(ip_key, "b l (h d) -> b h l d", h=attn.heads)
+        ip_value = rearrange(ip_value, "b l (h d) -> b h l d", h=attn.heads)
 
         # norm
         query = self.norm_q(query)
@@ -454,9 +442,9 @@ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_
         key = torch.cat([img_key, ip_key], dim=2)
         value = torch.cat([img_value, ip_value], dim=2)
 
-        #
+        #
         ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
+        ip_hidden_states = rearrange(ip_hidden_states, "b h l d -> b l (h d)")
         ip_hidden_states = ip_hidden_states.to(query.dtype)
         return ip_hidden_states
 
@@ -1049,10 +1037,10 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
-
     @torch.inference_mode()
     def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
-        from transformers import SiglipVisionModel, SiglipImageProcessor
+        from transformers import SiglipImageProcessor, SiglipVisionModel
+
         state_dict = torch.load(ip_adapter_path, map_location="cpu")
 
         device, dtype = self.transformer.device, self.transformer.dtype
@@ -1084,14 +1072,13 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d
 
         self.image_proj_model = image_proj_model
 
-
         attn_procs = {}
         transformer = self.transformer
         for idx_name, name in enumerate(transformer.attn_processors.keys()):
             hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
             ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
             ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
-
+
             attn_procs[name] = JointIPAttnProcessor(
                 hidden_size=hidden_size,
                 cross_attention_dim=transformer.config.caption_projection_dim,
@@ -1107,10 +1094,8 @@ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_d
         key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
         print(f"=> loading ip_adapter: {key_name}")
 
-
     @torch.inference_mode()
     def encode_clip_image_emb(self, clip_image, device, dtype):
-
         # clip
         clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
         clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
@@ -1119,8 +1104,6 @@ def encode_clip_image_emb(self, clip_image, device, dtype):
 
         return clip_image_embeds
 
-
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1150,7 +1133,6 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
-
         # ipa
         clip_image=None,
         ipadapter_scale=1.0,
@@ -1349,18 +1331,16 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])
 
                 image_prompt_embeds, timestep_emb = self.image_proj_model(
-                    clip_image_embeds,
-                    timestep.to(dtype=latents.dtype),
-                    need_temb=True
+                    clip_image_embeds, timestep.to(dtype=latents.dtype), need_temb=True
                 )
 
-                joint_attention_kwargs = dict(
-                    emb_dict=dict(
-                        ip_hidden_states=image_prompt_embeds,
-                        temb=timestep_emb,
-                        scale=ipadapter_scale,
-                    )
-                )
+                joint_attention_kwargs = {
+                    "emb_dict": {
+                        "ip_hidden_states": image_prompt_embeds,
+                        "temb": timestep_emb,
+                        "scale": ipadapter_scale,
+                    }
+                }
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
@@ -1420,4 +1400,4 @@ def __call__(
         if not return_dict:
             return (image,)
 
-        return StableDiffusion3PipelineOutput(images=image)
+        return StableDiffusion3PipelineOutput(images=image)
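
For orientation, a minimal, hypothetical usage sketch of the pipeline this file defines (not part of the commit). Loading via custom_pipeline, the base model id, the weight paths, the SigLIP encoder id, the nb_token value, and the prompt are all placeholders and assumptions; only init_ipadapter, clip_image, and ipadapter_scale are taken from the code in the diff above.

# Hypothetical usage sketch; ids, paths, and nb_token below are placeholders.
import torch
from PIL import Image
from diffusers import DiffusionPipeline

# Assumes the community file is resolvable as a custom pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",   # placeholder base model
    custom_pipeline="pipeline_stable_diffusion_3_ipa",
    torch_dtype=torch.float16,
).to("cuda")

# Attach the IP-Adapter weights and image encoder (signature as in the diff).
pipe.init_ipadapter(
    ip_adapter_path="ip-adapter.bin",                         # placeholder path
    image_encoder_path="google/siglip-so400m-patch14-384",    # placeholder encoder id
    nb_token=64,                                              # placeholder token count
)

# Condition generation on a reference image via clip_image / ipadapter_scale.
result = pipe(
    prompt="a photo of a cat",
    clip_image=Image.open("reference.jpg").convert("RGB"),
    ipadapter_scale=0.7,
)
result.images[0].save("out.png")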

0 commit comments
