-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsla_node.py
More file actions
330 lines (280 loc) · 11.7 KB
/
sla_node.py
File metadata and controls
330 lines (280 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
"""
ComfyUI Node for SLA (Sparse-Linear Attention)
Implements attention acceleration for diffusion models using SLA
"""
import sys
import torch
import folder_paths
import comfy.model_patcher
import comfy.samplers
from comfy.model_patcher import ModelPatcher
def _setup_spargeattn_compat():
"""
Compatibility shim for SpargeAttn pre-built Windows wheels.
Pre-built wheels from woct0rdho/SpargeAttn use architecture-specific modules
(_qattn_sm80, _qattn_sm89, _qattn_sm90) instead of the generic _qattn module
that source builds create. This shim detects the GPU architecture and
registers the appropriate module as spas_sage_attn._qattn so upstream
SageSLA code works without modification.
Works on:
- Linux/Windows source builds (uses _qattn directly)
- Windows pre-built wheels (falls back to _qattn_smXX)
- RTX 50xx Blackwell (sm120) via sm90 PTX compatibility
- RTX 40xx Ada (sm89)
- RTX 30xx Ampere (sm80/86)
"""
# Check if _qattn already exists (source build)
try:
import spas_sage_attn._qattn
return True
except ModuleNotFoundError:
pass
# Detect GPU architecture
if not torch.cuda.is_available():
return False
try:
props = torch.cuda.get_device_properties(0)
sm = props.major * 10 + props.minor
except Exception:
return False
# Import architecture-specific module and register as _qattn
try:
if sm >= 90:
# Hopper (sm90) and Blackwell (sm120) - use sm90 via PTX compatibility
import spas_sage_attn._qattn_sm90 as qattn
elif sm >= 89:
# Ada Lovelace (sm89)
import spas_sage_attn._qattn_sm89 as qattn
else:
# Ampere (sm80, sm86) and older
import spas_sage_attn._qattn_sm80 as qattn
# Register as _qattn so upstream SageSLA imports work
sys.modules['spas_sage_attn._qattn'] = qattn
return True
except ModuleNotFoundError:
return False
class SLAAttentionNode:
    """
    A ComfyUI node that applies SLA (Sparse-Linear Attention) to models.

    Accelerates diffusion model inference by combining sparse and linear
    attention; the patch is attached via the model patcher's attn1 hook.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's input sockets for the ComfyUI graph editor."""
        topk_cfg = {
            "default": 0.2,
            "min": 0.0,
            "max": 1.0,
            "step": 0.05,
            "display": "slider",
            "tooltip": "Sparsity ratio (0.2 = keep 80% of tokens)",
        }
        block_cfg = {"default": 64, "min": 16, "max": 256, "step": 16}
        return {
            "required": {
                "model": ("MODEL",),
                "topk": ("FLOAT", topk_cfg),
                "feature_map": (
                    ["softmax", "elu", "relu"],
                    {
                        "default": "softmax",
                        "tooltip": "Kernel type for attention computation",
                    },
                ),
                "block_size_q": (
                    "INT",
                    dict(block_cfg, tooltip="Block size for query processing"),
                ),
                "block_size_k": (
                    "INT",
                    dict(block_cfg, tooltip="Block size for key processing"),
                ),
                "enabled": (
                    "BOOLEAN",
                    {"default": True, "tooltip": "Enable/disable SLA attention"},
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_sla"
    CATEGORY = "model_patches/attention"
    DESCRIPTION = "Applies SLA (Sparse-Linear Attention) to accelerate model inference"

    def __init__(self):
        # Lazily-populated handle; the live module is kept in self.sla_cache.
        self.sla_module = None

    def apply_sla(self, model, topk, feature_map, block_size_q, block_size_k, enabled):
        """
        Apply SLA attention to the model.

        Args:
            model: ComfyUI MODEL object (a ModelPatcher).
            topk: Sparsity ratio (0.2 = 20% sparsity).
            feature_map: Type of feature map ('softmax', 'elu', 'relu').
            block_size_q: Block size for query.
            block_size_k: Block size for key.
            enabled: Whether to enable SLA.

        Returns:
            One-tuple holding the (possibly cloned and patched) MODEL.
        """
        if not enabled:
            return (model,)

        try:
            from sparse_linear_attention import SparseLinearAttention
        except ImportError:
            # Best-effort: without the package we hand back the input untouched.
            print("WARNING: SLA module not found. Please install with: pip install git+https://github.com/thu-ml/SLA.git")
            print("Returning unmodified model.")
            return (model,)

        # Work on a clone so the caller's model stays unmodified.
        patched = model.clone()
        # Touch the underlying model held by the patcher (kept from original code).
        inner_model = patched.model

        def sla_attention_patch(q, k, v, extra_options):
            """
            Custom attention function using SLA.

            Args:
                q/k/v: Query/key/value tensors, assumed [batch, heads, seq_len,
                    head_dim] — only head_dim is consumed here.
                extra_options: Additional options from ComfyUI (unused).

            Returns:
                Attention output tensor from the SLA module.
            """
            _, _, _, head_dim = q.shape

            # Rebuild the SLA module only when head_dim or a node parameter
            # changed since the previous call.
            cache_key = (head_dim, topk, feature_map, block_size_q, block_size_k)
            cache = getattr(self, 'sla_cache', None)
            if cache is None or cache.get('key') != cache_key:
                module = SparseLinearAttention(
                    head_dim=head_dim,
                    topk=topk,
                    feature_map=feature_map,
                    BLKQ=block_size_q,
                    BLKK=block_size_k,
                ).to(q.device).to(q.dtype)
                self.sla_cache = {'key': cache_key, 'module': module}

            # SLA expects input in format [batch, heads, seq_len, head_dim].
            return self.sla_cache['module'](q, k, v)

        patched.set_model_attn1_patch(sla_attention_patch)
        print(f"SLA Attention applied with topk={topk}, feature_map={feature_map}, block_size=({block_size_q}, {block_size_k})")
        return (patched,)
class SLASageAttentionNode:
    """
    A ComfyUI node that applies SageSLA (optimized SLA) to models.

    SageSLA is an optimized implementation based on the SageAttention
    framework. Requires:
    pip install git+https://github.com/thu-ml/SpargeAttn.git --no-build-isolation

    Note: SageSLA only supports fixed block sizes (BLKQ=128, BLKK=64).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's input sockets for the ComfyUI graph editor."""
        return {
            "required": {
                "model": ("MODEL",),
                "topk": (
                    "FLOAT",
                    {
                        "default": 0.2,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.05,
                        "display": "slider",
                        "tooltip": "Sparsity ratio (0.2 = keep 80% of tokens)",
                    },
                ),
                "feature_map": (
                    ["softmax", "elu", "relu"],
                    {
                        "default": "softmax",
                        "tooltip": "Kernel type for attention computation",
                    },
                ),
                "enabled": (
                    "BOOLEAN",
                    {
                        "default": True,
                        "tooltip": "Enable/disable SageSLA attention",
                    },
                ),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_sage_sla"
    CATEGORY = "model_patches/attention"
    DESCRIPTION = "Applies SageSLA (optimized SLA) for faster inference. Requires SpargeAttn installation."

    def __init__(self):
        # Lazily-populated handle; the live module is kept in self.sage_cache.
        self.sage_sla_module = None

    def apply_sage_sla(self, model, topk, feature_map, enabled):
        """
        Apply SageSLA (optimized) attention to the model.

        Args:
            model: ComfyUI MODEL object (a ModelPatcher).
            topk: Sparsity ratio.
            feature_map: Type of feature map ('softmax', 'elu', 'relu').
            enabled: Whether to enable SageSLA.

        Returns:
            One-tuple holding the (possibly cloned and patched) MODEL.
        """
        if not enabled:
            return (model,)

        try:
            # Setup compatibility shim for pre-built SpargeAttn wheels
            # (patches spas_sage_attn._qattn for architecture-specific modules).
            _setup_spargeattn_compat()

            # Prefer SageSLA; fall back to standard SLA when unavailable.
            try:
                from SageSLA import SageSparseLinearAttention
            except ImportError:
                from sparse_linear_attention import SparseLinearAttention
                use_sage = False
                print("INFO: Using standard SLA (SageSLA not found)")
                print("INFO: For optimized SageSLA on Windows, install the pre-built wheel:")
                print("INFO: pip install https://github.com/woct0rdho/SpargeAttn/releases/download/v0.1.0-windows.post3/spas_sage_attn-0.1.0%2Bcu128torch2.9.0.post3-cp39-abi3-win_amd64.whl")
            else:
                use_sage = True
        except ImportError:
            # Neither implementation is importable: return the model untouched.
            print("WARNING: SLA module not found. Please install with: pip install git+https://github.com/marduk191/SLA.git")
            print("Returning unmodified model.")
            return (model,)

        # Work on a clone so the caller's model stays unmodified.
        patched = model.clone()

        def sage_sla_attention_patch(q, k, v, extra_options):
            """Custom attention function using SageSLA (or the SLA fallback)."""
            _, _, _, head_dim = q.shape

            # Rebuild the module only when head_dim, a node parameter, or the
            # chosen implementation changed since the previous call.
            cache_key = (head_dim, topk, feature_map, use_sage)
            cache = getattr(self, 'sage_cache', None)
            if cache is None or cache.get('key') != cache_key:
                # Only the branch actually taken is evaluated, so the unbound
                # alternative class name is never touched.
                attn_cls = SageSparseLinearAttention if use_sage else SparseLinearAttention
                module = attn_cls(
                    head_dim=head_dim,
                    topk=topk,
                    feature_map=feature_map,
                    BLKQ=128,  # Fixed for SageSLA
                    BLKK=64,   # Fixed for SageSLA
                ).to(q.device).to(q.dtype)
                self.sage_cache = {'key': cache_key, 'module': module}

            return self.sage_cache['module'](q, k, v)

        patched.set_model_attn1_patch(sage_sla_attention_patch)
        label = "SageSLA (optimized)" if use_sage else "Standard SLA"
        print(f"{label} Attention applied with topk={topk}, feature_map={feature_map}, block_size=(128, 64)")
        return (patched,)
# Registration tables consumed by ComfyUI at import time: internal node id ->
# implementing class, and internal node id -> human-readable menu label.
NODE_CLASS_MAPPINGS = dict(
    SLAAttention=SLAAttentionNode,
    SLASageAttention=SLASageAttentionNode,
)

NODE_DISPLAY_NAME_MAPPINGS = dict(
    SLAAttention="SLA Attention",
    SLASageAttention="SageSLA Attention (Optimized)",
)