Commit 429d2cf

Remove Flux IP Adapter logic for now

1 parent a1f2ba1 commit 429d2cf

File tree

1 file changed: +0 -135 lines changed

src/diffusers/models/transformers/transformer_flux2.py

Lines changed: 0 additions & 135 deletions
@@ -208,145 +208,10 @@ def __call__(
         return hidden_states


-# TODO: support IP Adapter for Flux.2 as well
-class FluxIPAdapterAttnProcessor(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    _attention_backend = None
-    _parallel_config = None
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: "Flux2Attention",
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        batch_size = hidden_states.shape[0]
-
-        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-            attn, hidden_states, encoder_hidden_states
-        )
-
-        query = query.unflatten(-1, (attn.heads, -1))
-        key = key.unflatten(-1, (attn.heads, -1))
-        value = value.unflatten(-1, (attn.heads, -1))
-
-        query = attn.norm_q(query)
-        key = attn.norm_k(key)
-        ip_query = query
-
-        if encoder_hidden_states is not None:
-            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
-            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
-            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
-
-            encoder_query = attn.norm_added_q(encoder_query)
-            encoder_key = attn.norm_added_k(encoder_key)
-
-            query = torch.cat([encoder_query, query], dim=1)
-            key = torch.cat([encoder_key, key], dim=1)
-            value = torch.cat([encoder_value, value], dim=1)
-
-        if image_rotary_emb is not None:
-            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
-
-        hidden_states = dispatch_attention_fn(
-            query,
-            key,
-            value,
-            attn_mask=attention_mask,
-            dropout_p=0.0,
-            is_causal=False,
-            backend=self._attention_backend,
-            parallel_config=self._parallel_config,
-        )
-        hidden_states = hidden_states.flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
-            )
-            hidden_states = attn.to_out[0](hidden_states)
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
-
-                current_ip_hidden_states = dispatch_attention_fn(
-                    ip_query,
-                    ip_key,
-                    ip_value,
-                    attn_mask=None,
-                    dropout_p=0.0,
-                    is_causal=False,
-                    backend=self._attention_backend,
-                    parallel_config=self._parallel_config,
-                )
-                current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = Flux2AttnProcessor
     _available_processors = [
         Flux2AttnProcessor,
-        FluxIPAdapterAttnProcessor,
     ]

     def __init__(
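For context on what the removed processor did: it follows the standard IP-Adapter attention pattern, where each adapter adds its own key/value projections over image-prompt embeddings, the image-stream query attends to those tokens, and the per-adapter results are scaled and summed into an extra output (`ip_attn_output`) returned alongside the text and image streams. Below is a minimal standalone sketch of that pattern in plain PyTorch; the names (`IPAdapterKV`, `ip_adapter_attention`) and toy dimensions are illustrative assumptions, not diffusers APIs.

import torch
import torch.nn as nn
import torch.nn.functional as F


class IPAdapterKV(nn.Module):
    """Extra key/value projections for one image-prompt adapter (illustrative)."""

    def __init__(self, cross_attention_dim: int, hidden_size: int):
        super().__init__()
        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)
        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=True)


def ip_adapter_attention(query, ip_hidden_states, adapters, scales):
    # query: (batch, seq, heads, head_dim), the same query used for the base attention
    # ip_hidden_states: list of (batch, num_ip_tokens, cross_attention_dim), one per adapter
    batch, seq, heads, head_dim = query.shape
    q = query.permute(0, 2, 1, 3)  # SDPA expects (batch, heads, seq, head_dim)
    out = torch.zeros(batch, seq, heads * head_dim, dtype=query.dtype, device=query.device)
    for ip_states, adapter, scale in zip(ip_hidden_states, adapters, scales):
        k = adapter.to_k_ip(ip_states).view(batch, -1, heads, head_dim).permute(0, 2, 1, 3)
        v = adapter.to_v_ip(ip_states).view(batch, -1, heads, head_dim).permute(0, 2, 1, 3)
        attn = F.scaled_dot_product_attention(q, k, v)  # cross-attend to image-prompt tokens
        attn = attn.permute(0, 2, 1, 3).reshape(batch, seq, heads * head_dim)
        out = out + scale * attn  # per-adapter contribution, scaled and summed
    return out


if __name__ == "__main__":
    heads, head_dim, cross_dim = 4, 16, 32
    adapters = nn.ModuleList([IPAdapterKV(cross_dim, heads * head_dim)])
    query = torch.randn(2, 10, heads, head_dim)
    ip_states = [torch.randn(2, 4, cross_dim)]
    print(ip_adapter_attention(query, ip_states, adapters, scales=[1.0]).shape)  # (2, 10, 64)

In the removed code the same accumulation runs over `self.to_k_ip` / `self.to_v_ip` via `dispatch_attention_fn`, and the caller is presumably expected to add the returned `ip_attn_output` onto the image hidden states.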

0 commit comments
