Skip to content

Commit f14af9a

Browse files
committed
fix numerical instability for sageattention
1 parent 9d076a3 commit f14af9a

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

CrossAttentionPatch.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import math
33
import torch.nn.functional as F
4-
from comfy.ldm.modules.attention import optimized_attention
4+
from comfy.ldm.modules.attention import attention_sage, optimized_attention
55
from .utils import tensor_to_size
66

77
class Attn2Replace:
@@ -64,7 +64,10 @@ def to(self, device, *args, **kwargs):
6464

6565
def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=None, weight=1.0, cond=None, cond_alt=None, uncond=None, weight_type="linear", mask=None, sigma_start=0.0, sigma_end=1.0, unfold_batch=False, embeds_scaling='V only', **kwargs):
6666
ipadapter = ipadapter.get_multigpu_clone(q.device)
67-
67+
68+
epsilon = 0.0
69+
if optimized_attention == attention_sage:
70+
epsilon = 1e-5
6871
dtype = q.dtype
6972
cond_or_uncond = extra_options["cond_or_uncond"]
7073
block_type = extra_options["block"][0]
@@ -105,17 +108,17 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No
105108
if weight_type == "style transfer precise":
106109
if layers == 11 and t_idx == 3:
107110
uncond = cond
108-
cond = cond * 0
111+
cond = cond * epsilon
109112
elif layers == 16 and (t_idx == 4 or t_idx == 5):
110113
uncond = cond
111-
cond = cond * 0
114+
cond = cond * epsilon
112115
elif weight_type == "composition precise":
113116
if layers == 11 and t_idx != 3:
114117
uncond = cond
115-
cond = cond * 0
118+
cond = cond * epsilon
116119
elif layers == 16 and (t_idx != 4 and t_idx != 5):
117120
uncond = cond
118-
cond = cond * 0
121+
cond = cond * epsilon
119122

120123
weight = weight[t_idx]
121124

@@ -170,7 +173,7 @@ def ipadapter_attention(out, q, k, v, extra_options, module_key='', ipadapter=No
170173
weight = weight.repeat(len(cond_or_uncond), 1, 1) # repeat for cond and uncond
171174
elif weight == 0:
172175
return 0
173-
176+
174177
k_cond = ipadapter.ip_layers.to_kvs[k_key](cond).repeat(batch_prompt, 1, 1)
175178
k_uncond = ipadapter.ip_layers.to_kvs[k_key](uncond).repeat(batch_prompt, 1, 1)
176179
v_cond = ipadapter.ip_layers.to_kvs[v_key](cond).repeat(batch_prompt, 1, 1)

0 commit comments

Comments
 (0)