Commit b5d48b7

Add KimiDeltaAttention(KDA) (#621)
1 parent 38cc619 commit b5d48b7

File tree

20 files changed: +3332, -12 lines

benchmarks/ops/benchmark_kda.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-

import os

import torch
import triton
from flash_attn import flash_attn_func
from torch.nn import functional as F

from fla.ops.comba import chunk_comba
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule
from fla.ops.kda import chunk_kda


@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['T'],
        # different possible values for `x_name`
        x_vals=[256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['gdn', 'comba', 'kda', 'dplr', 'attn'],
        # label name for the lines
        line_names=['gdn', 'comba', 'kda', 'dplr', 'attn'],
        # line styles
        styles=[('blue', '-'), ('red', '-.'), ('green', '-'), ('orange', '-.'),
                ('purple', '-'), ('brown', '-.'), ('pink', '-'), ('gray', '-.')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    )
)
def benchmark(T, provider):
    from fla.utils import device
    dtype = torch.bfloat16
    B, H, D = 1, 16, 128

    # Set TMA environment variable based on provider
    original_tma_env = os.environ.get('FLA_USE_TMA', '0')

    if provider.endswith('_no_tma'):
        os.environ['FLA_USE_TMA'] = '0'
        provider_base = provider.replace('_no_tma', '')
    else:
        os.environ['FLA_USE_TMA'] = '1'
        provider_base = provider

    quantiles = [0.5, 0.2, 0.8]
    results = 0, 0, 0

    do = torch.randn(B, T, H, D, dtype=dtype, device=device)
    if provider_base == 'gdn':
        q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device)).requires_grad_(True)
        beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
        results = triton.testing.do_bench(
            lambda: chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                use_qk_l2norm_in_kernel=True,
            )[0].backward(do),
            quantiles=quantiles
        )
    elif provider_base == 'attn':
        q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        results = triton.testing.do_bench(
            lambda: flash_attn_func(
                q=q,
                k=k,
                v=v,
            ).backward(do),
            quantiles=quantiles
        )
    elif provider_base == 'comba':
        q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        p = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        g = F.logsigmoid(torch.randn(B, T, H, dtype=torch.float, device=device)).requires_grad_(True)
        beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
        results = triton.testing.do_bench(
            lambda: chunk_comba(
                q=q,
                k=k,
                p=p,
                v=v,
                g=g,
                beta=beta,
                use_qk_l2norm_in_kernel=True,
            )[0].backward(do),
            quantiles=quantiles
        )
    elif provider_base == 'kda':
        q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
        beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
        results = triton.testing.do_bench(
            lambda: chunk_kda(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                use_qk_l2norm_in_kernel=True,
            )[0].backward(do),
            quantiles=quantiles
        )
    elif provider_base == 'dplr':
        q = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        k = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        a = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        b = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        v = torch.randn(B, T, H, D, dtype=dtype, device=device).requires_grad_(True)
        g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device)).requires_grad_(True)
        beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid().requires_grad_(True)
        results = triton.testing.do_bench(
            lambda: chunk_dplr_delta_rule(
                q=q,
                k=k,
                v=v,
                a=a,
                b=b,
                gk=g,
            )[0].backward(do),
            quantiles=quantiles
        )

    # Restore original TMA environment variable
    os.environ['FLA_USE_TMA'] = original_tma_env
    return results


if __name__ == '__main__':
    benchmark.run(print_data=True)
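Run directly (python benchmarks/ops/benchmark_kda.py), the script prints a timing table per provider across sequence lengths T. For a quick shape check of the new kernel outside the Triton harness, here is a minimal forward-only sketch. It mirrors the benchmark's shapes and assumes only what the script above already shows: KDA's gate g is per-channel, shaped (B, T, H, D), unlike GDN's per-head (B, T, H) gate, and chunk_kda returns a tuple whose first element is the (B, T, H, D) output (the benchmark indexes [0]).

# Forward-only sanity check for chunk_kda, mirroring the benchmark shapes.
# A sketch, not part of the commit; assumptions are noted inline.
import torch
from torch.nn import functional as F

from fla.ops.kda import chunk_kda
from fla.utils import device

B, T, H, D = 1, 1024, 16, 128
dtype = torch.bfloat16
q = torch.randn(B, T, H, D, dtype=dtype, device=device)
k = torch.randn(B, T, H, D, dtype=dtype, device=device)
v = torch.randn(B, T, H, D, dtype=dtype, device=device)
# KDA's forget gate is per-channel (B, T, H, D); GDN/comba use per-head (B, T, H)
g = F.logsigmoid(torch.randn(B, T, H, D, dtype=dtype, device=device))
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid()

# the benchmark indexes [0], so the first element is the attention output
o = chunk_kda(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)[0]
assert o.shape == (B, T, H, D)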

fla/layers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .gsa import GatedSlotAttention
 from .hgrn import HGRNAttention
 from .hgrn2 import HGRN2Attention
+from .kda import KimiDeltaAttention
 from .lightnet import LightNetAttention
 from .linear_attn import LinearAttention
 from .log_linear_mamba2 import LogLinearMamba2
@@ -45,6 +46,7 @@
 'GatedSlotAttention',
 'HGRNAttention',
 'HGRN2Attention',
+'KimiDeltaAttention',
 'LightNetAttention',
 'LinearAttention',
 'LogLinearMamba2',
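With the re-export in place, the new layer resolves from the package root as well as from its own module. A minimal smoke test of the import path only (the constructor signature is not part of this diff, so no instance is created):

# Verify the package-level re-export added above; import path only,
# since KimiDeltaAttention's constructor is not shown in this diff.
from fla.layers import KimiDeltaAttention
from fla.layers.kda import KimiDeltaAttention as KDA

assert KimiDeltaAttention is KDA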

fla/layers/gated_deltanet.py

Lines changed: 2 additions & 2 deletions
@@ -281,7 +281,7 @@ def forward(
             initial_state=recurrent_state,
             output_final_state=use_cache,
             cu_seqlens=cu_seqlens,
-            use_qk_l2norm_in_kernel=True
+            use_qk_l2norm_in_kernel=True,
         )
     elif mode == 'fused_recurrent':
         o, recurrent_state = fused_recurrent_gated_delta_rule(
@@ -293,7 +293,7 @@ def forward(
             initial_state=recurrent_state,
             output_final_state=use_cache,
             cu_seqlens=cu_seqlens,
-            use_qk_l2norm_in_kernel=True
+            use_qk_l2norm_in_kernel=True,
         )
     else:
         raise NotImplementedError(f"Not supported mode `{mode}`.")
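The only change in this file is a trailing comma after use_qk_l2norm_in_kernel=True at both call sites; behavior is unchanged. Since the two branches compute the same gated delta rule with different kernels, a parity check between them can be sketched as below (shapes, dtypes, and tolerances are assumptions, not part of the commit):

# Rough parity check between the chunk and fused_recurrent paths touched above.
# A sketch under assumed shapes; loose tolerances to allow for bf16 accumulation differences.
import torch
from torch.nn import functional as F

from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from fla.utils import device

B, T, H, D = 1, 64, 2, 64
dtype = torch.bfloat16
q, k, v = (torch.randn(B, T, H, D, dtype=dtype, device=device) for _ in range(3))
g = F.logsigmoid(torch.randn(B, T, H, dtype=dtype, device=device))
beta = torch.randn(B, T, H, dtype=dtype, device=device).sigmoid()

o_chunk, _ = chunk_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)
o_rec, _ = fused_recurrent_gated_delta_rule(q=q, k=k, v=v, g=g, beta=beta, use_qk_l2norm_in_kernel=True)
torch.testing.assert_close(o_chunk, o_rec, rtol=1e-2, atol=1e-2)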
