forked from sorryhyun/comfyui-flex-attention
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
76 lines (54 loc) · 2.54 KB
/
__init__.py
File metadata and controls
76 lines (54 loc) · 2.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import logging
import torch
from comfy.ldm.modules.attention import wrap_attn, attention_pytorch
# Module-level logger for this custom-node package.
logger = logging.getLogger(__name__)
# One-shot flag: the "[FlexAttn] active" info line is emitted at most once
# per patch (FlexAttention.patch resets it).
_flex_logged = False
def get_flex_func():
    """Build and return a flex-attention implementation for ComfyUI.

    Compiles ``torch.nn.attention.flex_attention`` with ``torch.compile``
    and wraps it with ``wrap_attn`` so it matches ComfyUI's optimized
    attention signature.

    Returns:
        A ``@wrap_attn``-decorated callable ``attention_flex(q, k, v, heads,
        mask=None, ...)``. When ``mask`` is not ``None`` it delegates to
        ``attention_pytorch`` (the compiled flex call below is invoked
        without any mask/score-mod, so it cannot honor an explicit mask).
    """
    from torch.nn.attention.flex_attention import flex_attention
    compiled_flex = torch.compile(flex_attention)

    @wrap_attn
    def attention_flex(q, k, v, heads, mask=None, attn_precision=None,
                       skip_reshape=False, skip_output_reshape=False, **kwargs):
        # Masked attention falls back to PyTorch SDPA: flex_attention is
        # called here with no block/score mods, so a mask would be ignored.
        if mask is not None:
            return attention_pytorch(q, k, v, heads, mask=mask,
                                     skip_reshape=skip_reshape,
                                     skip_output_reshape=skip_output_reshape, **kwargs)
        global _flex_logged
        if not _flex_logged:
            _flex_logged = True
            logger.info(f"[FlexAttn] active: shape={list(q.shape)}, heads={heads}, dtype={q.dtype}")
        if skip_reshape:
            # Inputs already arrive head-split as (B, H, N, D) — use them
            # as-is. (Bug fix: previously dim_head was divided by heads
            # again and the tensors re-viewed, corrupting the layout.)
            b, _, _, dim_head = q.shape
        else:
            # Inputs are (B, N, H*D): split heads -> (B, H, N, D).
            b, _, dim_head = q.shape
            dim_head //= heads
            q = q.view(b, -1, heads, dim_head).transpose(1, 2)
            k = k.view(b, -1, heads, dim_head).transpose(1, 2)
            v = v.view(b, -1, heads, dim_head).transpose(1, 2)
        out = compiled_flex(q, k, v)
        if skip_output_reshape:
            return out  # (B, H, N, D)
        # Merge heads back: (B, H, N, D) -> (B, N, H*D).
        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return attention_flex
class FlexAttention:
    """ComfyUI node: swap the model's optimized attention for a
    torch.compile'd flex_attention implementation."""

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "attention"
    DESCRIPTION = "Patch attention to use PyTorch Flex Attention (torch.compile). No extra dependencies — works on any GPU supported by torch.compile. Falls back to PyTorch SDPA when masks are present."

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's single required input: the model to patch."""
        return {"required": {"model": ("MODEL",)}}

    def patch(self, model):
        """Clone *model* and install the flex-attention override on the clone.

        Returns a one-tuple containing the patched clone; the original
        model object is left untouched.
        """
        global _flex_logged
        # Re-arm the one-shot "[FlexAttn] active" log line for this patch.
        _flex_logged = False

        patched = model.clone()
        flex_impl = get_flex_func()

        def attention_override_flex(func, *args, **kwargs):
            # Ignore the attention `func` ComfyUI would have used and call
            # the unwrapped flex implementation directly.
            return flex_impl.__wrapped__(*args, **kwargs)

        patched.model_options["transformer_options"]["optimized_attention_override"] = attention_override_flex
        logger.info("[FlexAttn] Patched attention with compiled flex_attention")
        return (patched,)
# Registration tables read by ComfyUI when loading this custom node.
NODE_CLASS_MAPPINGS = dict(FlexAttention=FlexAttention)
NODE_DISPLAY_NAME_MAPPINGS = dict(FlexAttention="Flex Attention (torch.compile)")