Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions CrossAttentionPatch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import math
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.attention import attention_sage, optimized_attention
from .utils import tensor_to_size

class Attn2Replace:
Expand Down Expand Up @@ -64,7 +64,10 @@ def to(self, device, *args, **kwargs):

def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs):
ipadapter = ipadapter.get_multigpu_clone(q.device)


epsilon = 0.0
if optimized_attention == attention_sage:
epsilon = 1e-5
dtype = q.dtype
cond_or_uncond = extra_options["cond_or_uncond"]
block_type = extra_options["block"][0]
Expand Down Expand Up @@ -105,17 +108,17 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No
if weight_type == "style transfer precise":
if layers == 11 and t_idx == 3:
uncond = cond
cond = cond * 0
cond = cond * epsilon
elif layers == 16 and (t_idx == 4 or t_idx == 5):
uncond = cond
cond = cond * 0
cond = cond * epsilon
elif weight_type == "composition precise":
if layers == 11 and t_idx != 3:
uncond = cond
cond = cond * 0
cond = cond * epsilon
elif layers == 16 and (t_idx != 4 and t_idx != 5):
uncond = cond
cond = cond * 0
cond = cond * epsilon

weight = weight[t_idx]

Expand Down Expand Up @@ -170,7 +173,7 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No
weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
elif weight == 0:
return 0

k_cond = ipadapter.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1)
k_uncond = ipadapter.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1)
v_cond = ipadapter.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1)
Expand Down