Skip to content

Commit 9c1ed58

Browse files
proper fix for sag.
1 parent 8b90e50 commit 9c1ed58

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

comfy_extras/nodes_sag.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,23 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
5858
# Global Average Pool
5959
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
6060

61-
f = float(lh) / float(lw)
62-
fh = f ** 0.5
63-
fw = (1/f) ** 0.5
64-
S = mask.size(1) ** 0.5
65-
w = int(0.5 + S * fw)
66-
h = int(0.5 + S * fh)
61+
total = mask.shape[-1]
62+
x = round(math.sqrt((lh / lw) * total))
63+
xx = None
64+
for i in range(0, math.floor(math.sqrt(total) / 2)):
65+
for j in [(x + i), max(1, x - i)]:
66+
if total % j == 0:
67+
xx = j
68+
break
69+
if xx is not None:
70+
break
71+
72+
x = xx
73+
y = total // x
6774

6875
# Reshape
6976
mask = (
70-
mask.reshape(b, h, w)
77+
mask.reshape(b, x, y)
7178
.unsqueeze(1)
7279
.type(attn.dtype)
7380
)

0 commit comments

Comments
 (0)