Commit 117a23c

update kernel
1 parent 77b4a8e commit 117a23c

23 files changed, +310 -2158 lines changed
@@ -1,3 +1,50 @@
 {
-  "8448": null
+  "1": {
+    "BLK_HEADS": 64,
+    "num_warps": 1
+  },
+  "100": {
+    "BLK_HEADS": 4,
+    "num_warps": 4
+  },
+  "1024": {
+    "BLK_HEADS": 8,
+    "num_warps": 1
+  },
+  "128": {
+    "BLK_HEADS": 16,
+    "num_warps": 4
+  },
+  "16": {
+    "BLK_HEADS": 8,
+    "num_warps": 2
+  },
+  "2048": {
+    "BLK_HEADS": 16,
+    "num_warps": 1
+  },
+  "256": {
+    "BLK_HEADS": 32,
+    "num_warps": 2
+  },
+  "32": {
+    "BLK_HEADS": 8,
+    "num_warps": 1
+  },
+  "4096": {
+    "BLK_HEADS": 16,
+    "num_warps": 4
+  },
+  "64": {
+    "BLK_HEADS": 64,
+    "num_warps": 2
+  },
+  "8": {
+    "BLK_HEADS": 8,
+    "num_warps": 2
+  },
+  "8448": {
+    "BLK_HEADS": 32,
+    "num_warps": 4
+  }
 }
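The keys in this config appear to be problem sizes mapped to tuned Triton launch parameters (BLK_HEADS, num_warps). A minimal sketch of how such a file could be loaded and queried; the helper names and the nearest-size fallback policy are illustrative assumptions, not lightllm's actual lookup logic:

import json

def load_kernel_config(path):
    # Map each size key to its tuned launch parameters,
    # e.g. 1024 -> {"BLK_HEADS": 8, "num_warps": 1}.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_launch_params(config, size):
    # Exact match first; otherwise fall back to the nearest tuned size
    # (an assumed policy for sizes that were not benchmarked).
    if size in config:
        return config[size]
    nearest = min(config, key=lambda k: abs(k - size))
    return config[nearest]

# Usage: pick_launch_params(load_kernel_config("cfg.json"), 1024)
# -> {"BLK_HEADS": 8, "num_warps": 1}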

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ def _linear_attn(
             bias=layer_weight.linear_conv1d.mm_param.bias,
             query_start_loc=infer_state.b1_cu_q_seq_len,
             cache_indices=buffer_idx,
+            has_initial_state=infer_state.b_ready_cache_len > 0,
             conv_states=conv_states.transpose(1, 2),
             activation=self.activation,
         )
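The new has_initial_state argument is a per-request boolean mask derived from b_ready_cache_len: requests that already have cached tokens resume from their stored convolution state, while fresh prompts start from zeros. A small illustration of how the mask is computed (the example values are made up):

import torch

# Per-request count of tokens already in the cache: 0 means a fresh
# prompt, > 0 means this request resumes from previously cached state.
b_ready_cache_len = torch.tensor([0, 5, 0, 12])

# One flag per request: load prior conv state only where tokens are cached.
has_initial_state = b_ready_cache_len > 0
print(has_initial_state)  # tensor([False,  True, False,  True])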

lightllm/models/qwen3next/model.py

Lines changed: 10 additions & 6 deletions
@@ -1,3 +1,4 @@
+import os
 import torch
 from typing import Optional
 from typing_extensions import override
@@ -62,11 +63,11 @@ def _init_mem_manager(self):
         start_args: StartArgs = get_env_start_args()

         mtp_step = start_args.mtp_step
-        linear_attn_cache_size = start_args.linear_attn_cache_size
-        if linear_attn_cache_size is not None:
+        mamba_cache_size = start_args.mamba_cache_size
+        if mamba_cache_size is not None:
             assert (
-                linear_attn_cache_size >= start_args.running_max_req_size
-            ), "linear_attn_cache_size must be greater than running_max_req_size"
+                mamba_cache_size >= start_args.running_max_req_size
+            ), "mamba_cache_size must be greater than running_max_req_size"

         self.num_linear_k_heads = self.config["linear_num_key_heads"]
         self.num_linear_v_heads = self.config["linear_num_value_heads"]
@@ -78,9 +79,12 @@ def _init_mem_manager(self):
             self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
         )

+        ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+        assert start_args.mamba_ssm_data_type in ssm_dtype_dict
+
         self.mem_manager = Qwen3NextMemoryManager(
             full_attn_cache_size=self.max_total_token_num,
-            linear_attn_cache_size=linear_attn_cache_size,
+            linear_attn_cache_size=mamba_cache_size,
             dtype=self.data_type,
             num_kv_heads=self.num_kv_heads,
             head_dim=self.config["head_dim"],
@@ -89,7 +93,7 @@ def _init_mem_manager(self):
             full_attention_interval=self.config["full_attention_interval"],
             conv_state_dtype=self.data_type,
             conv_state_shape=(conv_kernel_size - 1 + mtp_step, conv_dim // self.tp_world_size_),
-            ssm_state_dtype=self.data_type,
+            ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type],
             ssm_state_shape=(
                 # mtp_step + 1,
                 self.num_linear_v_heads // self.tp_world_size_,
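Both options are read from the server start arguments. A hedged argparse sketch of how the renamed and new flags might be declared; the flag names are inferred from the attribute accesses above, and the defaults are assumptions, not lightllm's actual CLI definition:

import argparse

parser = argparse.ArgumentParser()
# Renamed from --linear_attn_cache_size; when set, it must be at least
# running_max_req_size (enforced by the assert above).
parser.add_argument("--mamba_cache_size", type=int, default=None)
# Storage dtype for the SSM state cache; the model code maps this string
# to torch.bfloat16 or torch.float32 via ssm_dtype_dict.
parser.add_argument("--mamba_ssm_data_type", type=str, default="bfloat16",
                    choices=["bfloat16", "float32"])

args = parser.parse_args(["--mamba_ssm_data_type", "float32"])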

lightllm/models/qwen3next/triton_kernel/fla/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -6,3 +6,6 @@
 # The original source code was licensed under the MIT license and included
 # the following copyright notice:
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# Adapted from
+# https://github.com/vllm-project/vllm

lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py

Lines changed: 0 additions & 225 deletions
This file was deleted.
