
Commit 7891743

littsk, hanwen-sun, Strivin0311, WT1W, lijinnn
committed
[WIP] Support CatGQA (#256)
* add pack_gqa template for bwd
* support pack_gqa for tile_scheduler
* finish mainloop
* support bwd_epilogue for pack_gqa
* fix test_flex_flash_attn
* format packgqa for ffa bwd
* add packgqa_swapab bench
* format
* removed useless argument from exps/grpcoll test
* updated generate_inst script to use re to extract the kernel function signatures
* speed up magi_attn_comm building by skipping building when instantiations have not changed
* added native_grpcoll_split_alignment envvar with checking
* impl _preprocess_args_for_split_alignment and added pragma: no cover for all __repr__
* minor fixed logging and repr
* minor fixed comments
* minor fixed comments
* minor fixed comments
* refactored buffer to extract the common output view out
* refactored buffer to add split alignment to view
* refactored buffer to add split alignment to view for lse
* implemented test intranode with split alignment
* minor updated test intranode
* updated test_intranode_grpcoll
* minor fixed comments
* added temp debug code to let benchmark meet the split alignment
* raised kNumTMABytesPerWarp to 216KB to support larger tokens
* implemented split_alignment for internode
* fixed a bytes count bug for internode; forbid pass_padded_out_buffer with split_alignment > 1
* updated benchmark settings
* Support per split token in static solver (#228)
* Modify the static solver so that each segment of input_split_size is divisible by the same number
* modify chunk logic in static solver
* Dyn solver split alignment (#230)
* add merge_with_split_alignment method in AttnRanges
* support split alignment in dynamic solver
* Relax INT_MAX buffer size limit for internode (#229)
* relaxed the buffer size up to INT_MAX limit for internode
* tested over INT_MAX buffer size in exp/grpcoll tests
* minor fixes
* added docstring for config funcs
* added minimum num bytes check for native grpcoll
* fixed tma bytes and num warps for internode cache notify kernel
* raised default num_rdma_bytes
* further fixed internode cache notify kernel for group reduce
* removed the temp debug code to make benchmark mask split-aligned
* add dynamic_solver_vis (#231)
* Dynamic split alignment (#233)
* added num_heads_q,kv,group to comm meta for dynamic solver; added separate split alignment for kv/qo
* added num_heads_q/kv to comm meta for dynamic solver
* supported split alignment varying from dtype
* added native_grpcoll_split_alignment to test_pipeline/test_pipeline_sdpa
* tested through dynamic split alignment for pipeline ut; added world size offset for seed
* added some comments
* added MAGI_ATTENTION_NATIVE_GRPCOLL_SPLIT_ALIGNMENT to docs
* updated the docs for MAGI_ATTENTION_AUTO_RANGE_MERGE
* build cp-bench docker image
* Update API for num_heads and head_dim (#236)
* updated and polished api for required num_heads_q, num_heads_kv, head_dim
* adjusted the calls in ut for updated APIs
* adjusted the calls in examples for updated APIs
* adjusted the calls in exps for updated APIs
* adjusted the calls in docs and readme for updated APIs, as well as deleting the magi_attn_varlen_dipatch and magi_attn_flex_dispatch deprecated APIs
* minor updated tests/test_api/test_interface.py
* minor updated benchmark dockerfile
* Support auto split alignment (#241)
* added head dim to comm meta
* supported auto split alignment w/o varying from dtypes
* minor updated repr and utils
* added strategy for calc_split_alignment
* hotfix switch envvars in bench
* speed up epilogue and avoid reading uninitialized memory
* polish code
* enhance packgqa
* code refactor and bug fix
* code refactor and bug fix
* code refactor and bug fix
* add ut
* fixed get_a2av_perm_idx kernel
* supported get_a2av_perm_idx for 32 nodes
* polish code
* polish code
* fix ci
* remove cuda12 build
* remove cuda12 build

---------

Co-authored-by: shw <shw20010329@163.com>
Co-authored-by: Strivin0311 <hyp@smail.nju.edu.cn>
Co-authored-by: WT1W <100120067+WT1W@users.noreply.github.com>
Co-authored-by: lijinnn <31332658+lijinnn@users.noreply.github.com>
Co-authored-by: Big-TRex <1910960034@qq.com>
1 parent 50947f6 commit 7891743

28 files changed, +1653 −737 lines changed
Lines changed: 359 additions & 0 deletions
@@ -0,0 +1,359 @@
# Copyright (c) 2025-2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime

import torch
from baselines.attn_impl import ffa_func
from baselines.utils import seed_everything
from einops import rearrange

from magi_attention.benchmarking import Benchmark, do_bench_flops, perf_report
from magi_attention.utils.sparse_utils import (
    flatten_block_mask_to_kv_shape,
    generate_block_sparse_pattern,
    generate_ranges_from_block_mask,
)

impls = ["ffa_packgqa_swapab", "ffa"]

# actual seqlen
seqlens = [8192 * 8]

# current block sparse attention always has low sparsity
sparsity_ratio = [0.05]
# ss = [k * 1024 for k in [4, 96, 128]]
ds = [128]
wds = ["fwd"]
attn_modes = ["GQA"]  # MHA, GQA
nhqs = [64]
num_groups = [8, 16]
# small K block
# q_block_sizes = [64, 64, 64, 64, 64]
# k_block_sizes = [64, 32, 16, 8, 1]
# small Q block
# q_block_sizes = [64, 32, 16, 8]
# k_block_sizes = [64, 64, 64, 64]
q_block_sizes = [1]
k_block_sizes = [64]
# large Q block and K block
# q_block_sizes = [64, 128]
# k_block_sizes = [64, 128]

assert len(q_block_sizes) == len(k_block_sizes)
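# Note: q_block_sizes and k_block_sizes are zipped pairwise in the config sweep below,
# so each (q_block_size, k_block_size) pair is benchmarked as one tile shape.
# The active pair (1, 64) presumably stresses the tiny-query-block case that the
# pack_gqa + swap_ab path is meant to speed up.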

b = 1

dtype = torch.bfloat16

bias = None
softmax_scale = None
dropout_p = 0.0
return_attn_probs = False

quantiles = [0.5, 0.2, 0.8]


attn_flops_configs = [
    Benchmark(
        x_names=["sparsity_ratio"],  # Argument names to use as an x-axis for the plot.
        x_vals=sparsity_ratio,  # Different possible values for `x_name`.
        x_log=False,  # Whether the x axis is logarithmic.
        line_arg="attn_impl",  # Argument name whose value corresponds to a different line in the plot.
        line_vals=impls,  # Possible values for `line_arg`.
        line_names=impls,  # Label name for the lines.
        styles=[  # Line styles.
            ("green", "--"),
            ("orange", "--"),
            ("steelblue", "--"),
            ("red", "-"),
        ],
        ylabel={  # Label name for the y-axis.
            "flops": "Throughput (TFLOPs/s)",
            "mem": "Peak Memory (GB)",
        },
        plot_name=(
            f"block sparse attn-{wd} attn_mode-{attn_mode} "
            f"{'n_head-' + str(nhq) if attn_mode == 'MHA' else f'n_head-{nhq}:{nhq // num_group}'}\n"
            f"block_size-{q_block_size}:{k_block_size} seq_len {seqlen}"
        ),
        # Name for the plot. Used also as a file name for saving the plot.
        args={  # Values for function arguments not in `x_names` and `y_name`.
            "hd": hd,
            "wd": wd,
            "q_block_size": q_block_size,
            "k_block_size": k_block_size,
            "seqlen": seqlen,
            "num_group": num_group,
            "attn_mode": attn_mode,
            "nhq": nhq,
        },
    )
    for hd in ds
    for wd in wds
    for q_block_size, k_block_size in zip(q_block_sizes, k_block_sizes)
    for seqlen in seqlens
    for num_group in num_groups
    for attn_mode in attn_modes
    for nhq in nhqs
]
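# With the lists above, this comprehension yields len(ds) * len(wds) * len(q_block_sizes)
# * len(seqlens) * len(num_groups) * len(attn_modes) * len(nhqs) = 2 Benchmark configs
# (one per num_group in [8, 16]); each config becomes one plot with sparsity_ratio on
# the x-axis and one line per impl in `impls`.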

seed_everything()


@perf_report(attn_flops_configs)
def sparse_attn_benchmark(
    sparsity_ratio,
    hd,
    wd,
    q_block_size,
    k_block_size,
    seqlen,
    num_group,
    attn_mode,
    nhq,
    attn_impl,
):
    assert b == 1, "for now, we only support b=1 for ffa"
    is_attn_impl_support_this_mask = True
    already_known_oom_before_run = False

    # --------- prepare arguments --------- #

    device = torch.cuda.current_device()
    orig_seq_len_q = orig_seq_len_k = seqlen  # for a square mask where sq == sk
    block_m = q_block_size
    block_n = k_block_size

    num_q_blocks_orig = orig_seq_len_q // block_m
    num_kv_blocks_orig = orig_seq_len_k // block_n
    orig_head = nhq
    if attn_mode == "MHA":
        nhk = nhq
    elif attn_mode == "GQA":
        nhk = nhq // num_group

    # prepare q, k ranges and calculate attn_flops
    # for now, we only do the bench for the block sparse mask.
    # block_mask, scores = generate_global_block_sparse_pattern(
    #     orig_head, num_q_blocks_orig, num_kv_blocks_orig, sparsity_ratio, device="cuda"
    # )

    block_mask, scores = generate_block_sparse_pattern(
        num_q_heads=nhq,
        num_kv_heads=nhk,
        num_q_blocks=num_q_blocks_orig,
        num_kv_blocks=num_kv_blocks_orig,
        sparsity=sparsity_ratio,
        device="cuda",
    )

    attn_flops = 4 * orig_seq_len_q * orig_seq_len_k * orig_head * hd * sparsity_ratio
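    # FLOPs accounting: a dense forward pass is two matmuls (Q @ K^T and P @ V), each
    # 2 * sq * sk * hd FLOPs per head, i.e. 4 * sq * sk * nheads * hd in total; scaling
    # by sparsity_ratio keeps only the fraction of blocks that are actually attended.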
    # --------- prepare data --------- #
    # flash style shape: (b,s,h,d)
    q = torch.randn(
        b, orig_seq_len_q, nhq, hd, device=device, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        b, orig_seq_len_k, nhk, hd, device=device, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        b, orig_seq_len_k, nhk, hd, device=device, dtype=dtype, requires_grad=False
    )

    # ffa style shape: (t,h,d)
    if attn_impl in ("ffa_packgqa_swapab", "ffa"):
        h1 = nhk
        q = rearrange(q, "b s (h1 h2) d -> (b h1 s) h2 d", h1=h1)
        k = rearrange(k, "b s h d -> (b h s) 1 d")
        v = rearrange(v, "b s h d -> (b h s) 1 d")
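    # Layout note: the rearranges above fold the (kv) head dim into the token dim, so
    # each kv head gets its own contiguous segment of the flattened sequence and the
    # per-head block-sparse mask can be expressed as plain (q_range, k_range) pairs.
    # For GQA, the h2 = nhq // nhk query heads that share a kv head stay in the head
    # dim of q, which is presumably what the pack_gqa path packs into a single tile.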

    if attn_impl in ("sdpa", "vsa", "vsa_triton", "flashinfer", "flex"):
        q = rearrange(q, "b s h d -> b h s d")
        k = rearrange(k, "b s h d -> b h s d")
        v = rearrange(v, "b s h d -> b h s d")

    # --------- prepare grads --------- #

    if wd == "bwd":
        attn_flops = attn_flops * 2.5
        do = torch.randn_like(q)
        # require grads
        [x.requires_grad_(True) for x in [q, k, v, do]]

    # --------- prepare func --------- #
    # is_attn_impl_support_this_mask = block_sparse_available(
    #     attn_impl, nhq, nhk, q_block_size, k_block_size, wd
    # )
    is_attn_impl_support_this_mask = True
    if is_attn_impl_support_this_mask:
        if attn_impl == "ffa_packgqa_swapab":
            # flatten the per-head block mask into the kv-flattened sequence layout used by ffa
            flat_block_sparse_mask = flatten_block_mask_to_kv_shape(block_mask)
            q_ranges, k_ranges = generate_ranges_from_block_mask(
                flat_block_sparse_mask, block_m, block_n
            )

            attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
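            # q_ranges / k_ranges are the flattened-token ranges of every active
            # (q block, kv block) pair in the sparse pattern; an attn_type_map of all
            # zeros presumably marks each range pair as plain full (non-causal)
            # attention, which is what a block sparse mask needs.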

            # TODO: we need to optimize choose_ref_block.
            # It's better to set ref_block_size manually for now.
            # ref_block_size = choose_ref_block((q_block_size, k_block_size))
            ref_block_size = (64, 64)

            def fn():
                return ffa_func(
                    q,
                    k,
                    v,
                    q_ranges=q_ranges,
                    k_ranges=k_ranges,
                    attn_type_map=attn_type_map,
                    auto_range_merge=True,  # we should enable auto_range_merge for block sparse masks.
                    ref_block_size=ref_block_size,
                    pack_gqa=True,
                    swap_ab=True,
                    disable_fwd_atomic_reduction=True,
                )
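            # Note on the flags (interpretation from the flag names, not kernel docs):
            # pack_gqa=True presumably processes the query heads sharing a kv head as
            # one packed tile, and swap_ab=True presumably swaps the GEMM operands so
            # that very small query blocks (q_block_size=1 here) still fill a tile.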

            if wd == "bwd":
                try:
                    o, *rest = fn()
                except Exception as e:
                    if "CUDA out of memory" not in str(e):
                        print(
                            f"Error occurred before running {attn_impl} with "
                            f"{q_block_size=}, {k_block_size=} "
                            f"when {seqlen=}, {hd=} during {wd}: {e=}"
                        )
                        raise e
                    already_known_oom_before_run = True

                def fn():
                    o.backward(do, retain_graph=True)

        elif attn_impl == "ffa":
            # flatten the per-head block mask into the kv-flattened sequence layout used by ffa
            flat_block_sparse_mask = flatten_block_mask_to_kv_shape(block_mask)
            q_ranges, k_ranges = generate_ranges_from_block_mask(
                flat_block_sparse_mask, block_m, block_n
            )

            attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")

            # ref_block_size = choose_ref_block((q_block_size, k_block_size))
            ref_block_size = (64, 64)

            def fn():
                return ffa_func(
                    q,
                    k,
                    v,
                    q_ranges=q_ranges,
                    k_ranges=k_ranges,
                    attn_type_map=attn_type_map,
                    auto_range_merge=True,  # we should enable auto_range_merge for block sparse masks.
                    ref_block_size=ref_block_size,
                    pack_gqa=False,
                    disable_fwd_atomic_reduction=True,
                )

            if wd == "bwd":
                try:
                    o, *rest = fn()
                except Exception as e:
                    if "CUDA out of memory" not in str(e):
                        print(
                            f"Error occurred before running {attn_impl} with "
                            f"{q_block_size=}, {k_block_size=} "
                            f"when {seqlen=}, {hd=} during {wd}: {e=}"
                        )
                        raise e
                    already_known_oom_before_run = True

                def fn():
                    o.backward(do, retain_graph=True)

    # --------- try to do the bench --------- #
    if is_attn_impl_support_this_mask:
        if already_known_oom_before_run:
            # -1 indicates oom
            perf_dict = {
                "flops": [-1, -1, -1],
                "mem": [-1, -1, -1],
            }
        else:
            try:
                # disable mem test to only test flops for now
                perf_dict = do_bench_flops(
                    fn,
                    quantiles=quantiles,
                    mem_record_mode="peak",
                )

                # --------- process report --------- #

                # post process the perf_dict
                def ms_to_tflops(ms: float) -> float:
                    return attn_flops / ms * 1e-9
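                # ms -> TFLOPs/s: (FLOPs / ms) * 1e3 = FLOPs/s, divided by 1e12 for
                # TFLOPs/s, hence the combined factor of 1e-9.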

                perf_dict["flops"] = list(map(ms_to_tflops, perf_dict["flops"]))

                # disable mem test
                def gb(m):
                    return m / 1024**3

                # perf_dict["mem"] = list(map(gb, perf_dict["mem"]))
            except Exception as e:
                if "CUDA out of memory" not in str(e):
                    print(
                        f"Error occurred while running {attn_impl} with "
                        f"{q_block_size=}, {k_block_size=} "
                        f"when {seqlen=}, {hd=} during {wd}: {e=}"
                    )
                    perf_dict = {
                        "flops": [-2, -2, -2],
                        "mem": [-2, -2, -2],
                    }
                    # raise e
                else:
                    # -1 indicates oom
                    perf_dict = {
                        "flops": [-1, -1, -1],
                        "mem": [-1, -1, -1],
                    }
                    print(
                        f"OOM occurred while running {attn_impl} with {q_block_size=}, {k_block_size=} "
                        f"when {seqlen=}, {hd=} during {wd}: {e=}"
                    )
    else:
        # -2 indicates not support
        perf_dict = {
            "flops": [-2, -2, -2],
            "mem": [-2, -2, -2],
        }

    return perf_dict


if __name__ == "__main__":
    script_dir = os.path.dirname(os.path.abspath(__file__))
    current_time = datetime.strftime(datetime.now(), "%Y-%m-%d_%H-%M-%S")
    out_root = os.path.join(
        script_dir, os.path.join("outs", f"bench_attn_{current_time}")
    )

    sparse_attn_benchmark.run(
        print_data=True, print_value_on_bar=False, save_path=out_root
    )
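
# Running this file as a script sweeps the configs above and, per save_path, presumably
# writes the printed tables and saved plots to outs/bench_attn_<timestamp>/ next to this file.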

0 commit comments