
Commit 45033ab (1 parent: e9f388e)

Fix the attention causal flag; rename collectors to align with the sglang wideep path naming; update the READMEs to reflect the renames; update the attention collectors to also collect q-heads 1 and 2 (#101)

File tree

10 files changed: +29 −47 lines

collector/README.md

Lines changed: 3 additions & 3 deletions
@@ -85,13 +85,13 @@ export OUTPUT_PATH=/path/to/output
 
 # Run DeepSeek-specific attention collector
 SGLANG_LOAD_FORMAT=dummy SGLANG_TEST_NUM_LAYERS=2 \
-python collect_attn.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH
+python collect_wideep_attn.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH
 
 # Run DeepSeek MLP collector
-python collect_mlp.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH
+python collect_wideep_mlp.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH
 
 # Run DeepSeek DeepEP MoE collector (requires 2+ GPUs)
-python collect_deepep_moe.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH \
+python collect_wideep_deepep_moe.py --model_path $MODEL_PATH --output_path $OUTPUT_PATH \
     --tp_size 2 --ep_size 2 --num_experts 256
 ```
See `sglang/README.md` for detailed documentation on these collectors.

collector/collect.py

Lines changed: 2 additions & 2 deletions
@@ -403,14 +403,14 @@ def collect_sglang(num_processes: int, ops: list[str] | None = None):
         {
             "name": "sglang",
             "type": "attention_context",
-            "module": "collector.sglang.collect_normal_attn",
+            "module": "collector.sglang.collect_attn",
             "get_func": "get_context_attention_test_cases",
             "run_func": "run_attention_torch",
         },
         {
             "name": "sglang",
             "type": "attention_generation",
-            "module": "collector.sglang.collect_normal_attn",
+            "module": "collector.sglang.collect_attn",
             "get_func": "get_generation_attention_test_cases",
             "run_func": "run_attention_torch",
         },
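The registry entries above name the collector module and its functions as strings. A plausible way such entries get consumed is dynamic import plus `getattr` — this is an assumption about `collect.py`'s internals, not something the diff confirms. The sketch below uses the stdlib `math` module as a runnable stand-in for `collector.sglang.collect_attn`:

```python
import importlib

# Hypothetical resolution of a registry entry like the ones in the hunk above.
# "math"/"isqrt" are stand-ins so this sketch runs anywhere; the real entry
# points at collector.sglang.collect_attn / get_context_attention_test_cases.
entry = {
    "name": "sglang",
    "type": "attention_context",
    "module": "math",     # real entry: "collector.sglang.collect_attn"
    "get_func": "isqrt",  # real entry: "get_context_attention_test_cases"
}

# Import the named module and look up the named function on it.
mod = importlib.import_module(entry["module"])
get_func = getattr(mod, entry["get_func"])

print(get_func(49))  # math.isqrt(49) -> 7
```

Keeping the module path as a string is what made this rename a two-line change: only the registry entries had to be updated, not any import statements.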

collector/deep_collector/extract_data.py

Lines changed: 4 additions & 4 deletions
@@ -738,13 +738,13 @@ def main():
     )
     parser.add_argument(
         "--output-normal",
-        default="./deepep_normal_perf.txt",
-        help="normal output TXT file path (default: ./deepep_normal_perf.txt)",
+        default="./wideep_deepep_normal_perf.txt",
+        help="normal output TXT file path (default: ./wideep_deepep_normal_perf.txt)",
     )
     parser.add_argument(
         "--output-ll",
-        default="./deepep_ll_perf.txt",
-        help="ll output TXT file path (default: ./deepep_ll_perf.txt)",
+        default="./wideep_deepep_ll_perf.txt",
+        help="ll output TXT file path (default: ./wideep_deepep_ll_perf.txt)",
     )
     args = parser.parse_args()
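The renamed defaults can be verified with a minimal argparse stub. This is a stand-in parser reproducing only the two options from the hunk above, not the full `extract_data.py` CLI:

```python
import argparse

# Minimal stand-in mirroring the renamed defaults from the hunk above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output-normal",
    default="./wideep_deepep_normal_perf.txt",
    help="normal output TXT file path (default: ./wideep_deepep_normal_perf.txt)",
)
parser.add_argument(
    "--output-ll",
    default="./wideep_deepep_ll_perf.txt",
    help="ll output TXT file path (default: ./wideep_deepep_ll_perf.txt)",
)

# With no CLI arguments the defaults apply; argparse maps --output-normal
# to the attribute output_normal.
args = parser.parse_args([])
print(args.output_normal)
print(args.output_ll)
```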

collector/sglang/README.md

Lines changed: 9 additions & 9 deletions
@@ -14,9 +14,9 @@ The collected performance data can be used for performance modeling, scheduling
 
 ## Overview
 
-- **collect_deepseek_attn.py**: Collects performance data for DeepSeek Attention (MLA) operators
-- **collect_deepep_moe.py**: Collects performance data for DeepSeek MoE operators
-- **collect_deepseek_mlp.py**: Collects performance data for Shared Expert (MLP) operators
+- **collect_wideep_attn.py**: Collects performance data for DeepSeek Attention (MLA) operators
+- **collect_wideep_deepep_moe.py**: Collects performance data for DeepSeek MoE operators
+- **collect_wideep_mlp.py**: Collects performance data for Shared Expert (MLP) operators
 
 ## Requirements
 
@@ -34,7 +34,7 @@ output_path = "/aiconfigurator/src/aiconfigurator/systems/data/h100_sxm/sglang/0
 ```
 
-## 1. Attention Operator Collection (collect_deepseek_attn.py)
+## 1. Attention Operator Collection (collect_wideep_attn.py)
 
 ### Features
 - Tests different attention backends (flashinfer, fa3)
@@ -47,7 +47,7 @@ output_path = "/aiconfigurator/src/aiconfigurator/systems/data/h100_sxm/sglang/0
 #### Basic Run with dummy weight
 ```bash
 export DEEPSEEK_MODEL_PATH=/path/to/deepseek-v3
-python collect_deepseek_attn.py
+python collect_wideep_attn.py
 ```
 #### Environment Variables
 - `DEEPSEEK_MODEL_PATH`: Path to DeepSeek model
@@ -72,7 +72,7 @@ Output format:
 framework,version,device,op_name,kernel_source,mla_dtype,kv_cache_dtype,num_heads,batch_size,isl,tp_size,step,latency
 ```
 
-## 2. MoE Operator Collection (collect_deepep_moe.py)
+## 2. MoE Operator Collection (collect_wideep_deepep_moe.py)
 
 ### Features
 - Tests DeepEP MoE operator performance
@@ -85,7 +85,7 @@ framework,version,device,op_name,kernel_source,mla_dtype,kv_cache_dtype,num_head
 #### Basic Run
 ```bash
 export DEEPSEEK_MODEL_PATH=/path/to/deepseek-v3
-python collect_deepep_moe.py
+python collect_wideep_deepep_moe.py
 ```
 
 #### Environment Variables
@@ -139,7 +139,7 @@ Output format:
 framework,version,device,op_name,kernel_source,moe_dtype,num_tokens,hidden_size,inter_size,topk,num_experts,moe_tp_size,moe_ep_size,distribution,latency
 ```
 
-## 3. MLP Operator Collection (collect_deepseek_mlp.py)
+## 3. MLP Operator Collection (collect_wideep_mlp.py)
 
 ### Features
 - Tests DeepSeek V2/V3 MLP operator performance
@@ -151,7 +151,7 @@ framework,version,device,op_name,kernel_source,moe_dtype,num_tokens,hidden_size,
 #### Basic Run
 ```bash
 export DEEPSEEK_MODEL_PATH=/path/to/deepseek-v3
-python collect_deepseek_mlp.py
+python collect_wideep_mlp.py
 ```
 
 #### Environment Variables
Lines changed: 6 additions & 24 deletions
@@ -23,7 +23,7 @@ def get_context_attention_test_cases():
     test_cases = []
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     s_list = [16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 10240, 12288, 16384, 262144]
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64, 96]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64, 96]
     n_kv_list = [0, 1, 2, 4, 8]
     for n in sorted(n_list, reverse=True):
         for s in sorted(s_list, reverse=True):
@@ -74,8 +74,8 @@ def get_generation_attention_test_cases():
     # the i-th token to record. 1 for context phase. mapping to osl definition
     s_list = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     # full n {4, 5, 7, 8, 9, 10, 12, 14, 16, 18, 20, 24, 28, 32, 36, 40, 48, 56, 72, 96}
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64]
-    n_list_xqa = [4, 8, 16, 32, 64, 96, 128]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64]
+    n_list_xqa = [1, 2, 4, 8, 16, 32, 64, 96, 128]
     n_kv_list = [1, 2, 4, 8]
 
     # MHA
@@ -193,14 +193,7 @@ def run_attention_torch(
     k_cache, v_cache = [x.detach().to(kvtype).requires_grad_() for x in [k_cache, v_cache]]
     k, v, cache_seqlens = None, None, None
 
-    def float16attn_fp8kvcache(q, k_cache, v_cache, k, v, **kwargs):
-        k_cache = k_cache.to(torch.bfloat16)
-        v_cache = v_cache.to(torch.bfloat16)
-        k = None if k is None else k.to(torch.bfloat16)
-        v = None if v is None else v.to(torch.bfloat16)
-        flash_attn_func_v3(q, k_cache, v_cache, k, v, **kwargs)
-
-    if use_fp8_context_fmha:
+    if use_fp8_context_fmha or use_fp8_kv_cache:
         q = q.to(kvtype)
         m1 = time_fwd(
             flash_attn_func_v3,
@@ -210,19 +203,7 @@ def float16attn_fp8kvcache(q, k_cache, v_cache, k, v, **kwargs):
             k,
             v,
             cache_seqlens=cache_seqlens,
-            repeats=10,
-            verbose=True,
-            desc="Fav3",
-        )
-    elif use_fp8_kv_cache:
-        m1 = time_fwd(
-            float16attn_fp8kvcache,
-            q,
-            k_cache,
-            v_cache,
-            k,
-            v,
-            cache_seqlens=cache_seqlens,
+            causal=True,
             repeats=10,
             verbose=True,
             desc="Fav3",
@@ -236,6 +217,7 @@ def float16attn_fp8kvcache(q, k_cache, v_cache, k, v, **kwargs):
             k,
             v,
             cache_seqlens=cache_seqlens,
+            causal=True,
             repeats=10,
             verbose=True,
             desc="Fav3",
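The causal fix above passes `causal=True` to the benchmarked flash-attention calls, so the timed kernel matches the masked attention that decoder models actually run. The toy below (not the collector's code; head_dim of 1, values omitted) illustrates what a causal mask does to the attention weights:

```python
import math

# Illustrative toy: with a causal mask, query i may attend only to keys
# j <= i. Masked scores are set to -inf before softmax, so future positions
# get weight exactly 0 — roughly half the score matrix becomes dead work a
# causal kernel can skip, which is why the flag matters for benchmark parity.

def softmax(row):
    m = max(x for x in row if x != float("-inf"))
    exps = [0.0 if x == float("-inf") else math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def attention_weights(q, k, causal):
    # scores[i][j] = q_i * k_j (scalar "head_dim 1" for brevity)
    n = len(q)
    weights = []
    for i in range(n):
        scores = [
            float("-inf") if (causal and j > i) else q[i] * k[j]
            for j in range(n)
        ]
        weights.append(softmax(scores))
    return weights

w = attention_weights([1.0, 1.0, 1.0], [0.5, 0.5, 0.5], causal=True)
print(w[0])  # first query attends only to position 0: [1.0, 0.0, 0.0]
```

Without the flag, a bidirectional kernel computes (and times) the full score matrix, so timings collected before this fix would understate causal-decode performance relative to the serving path.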
File renamed without changes.

collector/trtllm/collect_attn.py

Lines changed: 3 additions & 3 deletions
@@ -293,7 +293,7 @@ def get_context_attention_test_cases():
         16384,
         262144,
     ]
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64, 96]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64, 96]
     n_kv_list = [0, 1, 2, 4, 8]
     head_dim = [64, 128]
 
@@ -507,8 +507,8 @@ def get_generation_attention_test_cases():
         65536,
         131072,
     ]
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64]
-    n_list_xqa = [4, 8, 16, 32, 64, 96, 128]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64]
+    n_list_xqa = [1, 2, 4, 8, 16, 32, 64, 96, 128]
     n_kv_list = [1, 2, 4, 8]
     head_dim = [64, 128]

collector/vllm/collect_attn.py

Lines changed: 2 additions & 2 deletions
@@ -304,7 +304,7 @@ def get_context_attention_test_cases(if_unit_test=False):
         16384,
         262144,
     ]
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64]
     n_kv_list = [0, 1, 2, 4, 8]
     # n_kv_list = [64]
 else:
@@ -360,7 +360,7 @@ def get_generation_attention_test_cases():
 
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
     # b_list_xqa = [1,2,4,8,16,32,64,128,256,512,1024,2048]
-    n_list = [4, 8, 12, 16, 24, 32, 40, 48, 64]
+    n_list = [1, 2, 4, 8, 12, 16, 24, 32, 40, 48, 64]
     # n_list_xqa = [4,8,16,32,64,128]
     s_list = [
         2,
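Across the sglang, trtllm, and vllm collectors the change is the same: q-head counts 1 and 2 are prepended to the sweep lists. The sketch below shows how such nested sweeps expand into test cases; the loop order mirrors `get_context_attention_test_cases` from the sglang hunk, and the lists are truncated here for brevity (the real sweeps are much larger, and other backends may nest their loops differently):

```python
# Hedged sketch of a collector-style sweep expansion (truncated lists).
n_list = [1, 2, 4, 8]  # q-head counts; 1 and 2 are the newly added values
s_list = [16, 32]      # sequence lengths (truncated)
b_list = [1, 2]        # batch sizes (truncated)

test_cases = []
for n in sorted(n_list, reverse=True):      # largest head counts first
    for s in sorted(s_list, reverse=True):  # longest sequences first
        for b in b_list:
            test_cases.append((b, s, n))

print(len(test_cases))  # 4 head counts x 2 seq lens x 2 batches = 16 cases
```

Each added head-count value multiplies the whole grid, so adding 1 and 2 grows every sweep by two full (s, b) planes of benchmark runs.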
