Commit 83d969e

Disable xformers when tracing model.
Parent: 1900e51

1 file changed: 11 additions, 1 deletion

comfy/ldm/modules/attention.py
@@ -313,9 +313,19 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None):
 def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
     b, _, dim_head = q.shape
     dim_head //= heads
+
+    disabled_xformers = False
+
     if BROKEN_XFORMERS:
         if b * heads > 65535:
-            return attention_pytorch(q, k, v, heads, mask)
+            disabled_xformers = True
+
+    if not disabled_xformers:
+        if torch.jit.is_tracing() or torch.jit.is_scripting():
+            disabled_xformers = True
+
+    if disabled_xformers:
+        return attention_pytorch(q, k, v, heads, mask)

     q, k, v = map(
         lambda t: t.reshape(b, -1, heads, dim_head),
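
For context on why the commit needs a call-time guard: torch.jit.trace records only the operations actually executed on the example inputs, so checking torch.jit.is_tracing() / torch.jit.is_scripting() inside the function steers the recorded graph onto the plain PyTorch attention path rather than the opaque xformers kernel. Below is a minimal, self-contained sketch of that dispatch pattern; the Toy module and its stand-in branches are hypothetical illustrations, not code from the commit.

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            # Same guard the commit adds: while tracing or scripting,
            # take the fallback branch so only traceable ops end up in
            # the recorded graph.
            if torch.jit.is_tracing() or torch.jit.is_scripting():
                return x * 2.0   # stand-in for attention_pytorch(...)
            return x + 1.0       # stand-in for the xformers kernel

    m = Toy()
    x = torch.full((1,), 3.0)
    print(m(x))                  # eager path: tensor([4.])
    traced = torch.jit.trace(m, x)
    print(traced(x))             # traced graph baked in the fallback: tensor([6.])

Note also the design choice in the diff itself: instead of returning early from each condition, the function accumulates every reason to skip xformers (the BROKEN_XFORMERS batch limit, tracing, scripting) into a single disabled_xformers flag and falls back to attention_pytorch in one place.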
