
Commit d9c1683

make ip adapter processor compatible with attention dispatcher
1 parent 4f52e34 commit d9c1683

File tree

1 file changed: +59 -83 lines changed


src/diffusers/models/transformers/transformer_flux.py

Lines changed: 59 additions & 83 deletions
@@ -42,39 +42,42 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class FluxAttnProcessor:
-    _attention_backend = None
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
 
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
 
-    def _get_projections(self, attn, hidden_states, encoder_hidden_states=None):
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
+    return query, key, value, encoder_query, encoder_key, encoder_value
 
-        encoder_query = encoder_key = encoder_value = None
-        if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-            encoder_query = attn.add_q_proj(encoder_hidden_states)
-            encoder_key = attn.add_k_proj(encoder_hidden_states)
-            encoder_value = attn.add_v_proj(encoder_hidden_states)
 
-        return query, key, value, encoder_query, encoder_key, encoder_value
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
 
-    def _get_fused_projections(self, attn, hidden_states, encoder_hidden_states=None):
-        query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+    encoder_query = encoder_key = encoder_value = (None,)
+    if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+        encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
 
-        encoder_query = encoder_key = encoder_value = (None,)
-        if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
-            encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+    return query, key, value, encoder_query, encoder_key, encoder_value
 
-        return query, key, value, encoder_query, encoder_key, encoder_value
 
-    def get_qkv_projections(self, attn: AttentionModuleMixin, hidden_states, encoder_hidden_states=None):
-        if attn.fused_projections:
-            return self._get_fused_projections(attn, hidden_states, encoder_hidden_states)
-        return self._get_projections(attn, hidden_states, encoder_hidden_states)
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    if attn.fused_projections:
+        return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+    _attention_backend = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
 
     def __call__(
         self,
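
Note: the refactor moves the projection helpers out of FluxAttnProcessor and into module-level functions, so any processor (including the IP-adapter one below) can reuse them. A minimal sketch of the fused-vs-unfused dispatch, using a hypothetical ToyAttention stand-in rather than the real FluxAttention:

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    # Hypothetical stand-in for FluxAttention, only for illustration.
    def __init__(self, dim=8, fused=False):
        super().__init__()
        self.fused_projections = fused
        if fused:
            self.to_qkv = nn.Linear(dim, 3 * dim)
        else:
            self.to_q = nn.Linear(dim, dim)
            self.to_k = nn.Linear(dim, dim)
            self.to_v = nn.Linear(dim, dim)


def toy_get_qkv(attn, hidden_states):
    # Mirrors the shape of _get_qkv_projections: one fused matmul when the
    # projections are fused, three separate projections otherwise.
    if attn.fused_projections:
        return attn.to_qkv(hidden_states).chunk(3, dim=-1)
    return attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)


x = torch.randn(2, 4, 8)
q, k, v = toy_get_qkv(ToyAttention(fused=True), x)
assert q.shape == k.shape == v.shape == (2, 4, 8)
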
@@ -84,7 +87,7 @@ def __call__(
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        query, key, value, encoder_query, encoder_key, encoder_value = self.get_qkv_projections(
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
             attn, hidden_states, encoder_hidden_states
         )
 
@@ -180,55 +183,35 @@ def __call__(
         ip_hidden_states: Optional[List[torch.Tensor]] = None,
         ip_adapter_masks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        batch_size = hidden_states.shape[0]
 
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
 
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
 
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+        ip_query = query
 
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
         if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
 
         if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
 
         hidden_states = dispatch_attention_fn(
             query,
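
Note: the processor now keeps attention tensors in (batch, seq_len, heads, head_dim) layout instead of (batch, heads, seq_len, head_dim), which is why apply_rotary_emb is called with sequence_dim=1 and why no transpose is needed before dispatch_attention_fn. A small sketch of the layout change, with illustrative shapes:

import torch

batch, seq_len, heads, head_dim = 2, 4, 3, 5
proj = torch.randn(batch, seq_len, heads * head_dim)

# Old layout: sequence on dim 2.
old = proj.view(batch, -1, heads, head_dim).transpose(1, 2)
assert old.shape == (batch, heads, seq_len, head_dim)

# New layout: sequence stays on dim 1.
new = proj.unflatten(-1, (heads, -1))
assert new.shape == (batch, seq_len, heads, head_dim)

# Same values, only the axis order differs.
assert torch.equal(old.transpose(1, 2), new)
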
@@ -239,23 +222,18 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
         )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)
 
         if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
             )
-
-            # linear proj
             hidden_states = attn.to_out[0](hidden_states)
-            # dropout
            hidden_states = attn.to_out[1](hidden_states)
             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
             # IP-adapter
-            ip_query = hidden_states_query_proj
             ip_attn_output = torch.zeros_like(hidden_states)
 
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
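
Note: split_with_sizes replaces the manual slicing of the joint sequence; the text (encoder) tokens come first because they were concatenated in front of the image tokens. A sketch of the equivalence, with illustrative shapes:

import torch

text_len, image_len = 512, 64
joint = torch.randn(2, text_len + image_len, 128)

enc, hid = joint.split_with_sizes([text_len, joint.shape[1] - text_len], dim=1)
assert torch.equal(enc, joint[:, :text_len])
assert torch.equal(hid, joint[:, text_len:])
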
@@ -264,10 +242,9 @@ def __call__(
                 ip_key = to_k_ip(current_ip_hidden_states)
                 ip_value = to_v_ip(current_ip_hidden_states)
 
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
+                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
                 current_ip_hidden_states = dispatch_attention_fn(
                     ip_query,
                     ip_key,
@@ -277,9 +254,7 @@ def __call__(
                     is_causal=False,
                     backend=self._attention_backend,
                 )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
+                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
                 current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
                 ip_attn_output += scale * current_ip_hidden_states
 
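
Note: because the dispatcher output is already (batch, seq_len, heads, head_dim), merging the heads back needs only a reshape, with no transpose. A sketch of the flatten/reshape equivalence used above, with illustrative shapes:

import torch

batch, seq_len, heads, head_dim = 2, 4, 3, 5
attn_out = torch.randn(batch, seq_len, heads, head_dim)

merged_a = attn_out.flatten(2, 3)
merged_b = attn_out.reshape(batch, -1, heads * head_dim)
assert torch.equal(merged_a, merged_b)
assert merged_a.shape == (batch, seq_len, heads * head_dim)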

@@ -316,6 +291,7 @@ def __init__(
         super().__init__()
         assert qk_norm == "rms_norm", "Flux uses RMSNorm"
 
+        self.head_dim = dim_head
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
         self.query_dim = query_dim
         self.use_bias = bias
