
Commit d35a2cf

Author: wangzaijun
Commit message: fix all

1 parent 809829d commit d35a2cf

File tree

7 files changed: +239 −20 lines

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 0 deletions

@@ -283,6 +283,10 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
                 infer_state.b_ready_cache_len = model_input.b_ready_cache_len
             else:
                 infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
+        else:
+            if enable_diverse_mode_gqa_decode_fast_kernel():
+                infer_state.b_shared_seq_len = model_input.b_shared_seq_len
+                infer_state.b_mark_shared_group = model_input.b_mark_shared_group
 
         infer_state.multimodal_params = model_input.multimodal_params
 
lightllm/common/basemodel/infer_struct.py

Lines changed: 4 additions & 0 deletions

@@ -24,6 +24,10 @@ def __init__(self):
         self.b_req_idx: torch.Tensor = None
         self.b_start_loc: torch.Tensor = None
         self.b_ready_cache_len: torch.Tensor = None  # only for prefill prompt cache used.
+
+        self.b_shared_seq_len: torch.Tensor = None  # only for diverse kv cache used in decode phase.
+        self.b_mark_shared_group: torch.Tensor = None  # only for diverse kv cache used in decode phase.
+
         self.b_seq_len: torch.Tensor = None
         # max_len_in_batch has a different meaning in the prefill and decode phases
         # in the prefill phase it is the maximum input-token length over all reqs (excluding the already cached part)
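The two new fields describe, per request in a decode batch, how much of its KV cache is a prefix shared with neighboring requests. Below is a minimal sketch with illustrative values only; the exact group-mark convention is the one described in generic_pre_process.py further down, and is an assumption here.

import torch

# Hypothetical decode batch of four requests: the first three were forked from the same
# prompt for diverse sampling and share a 512-token prefix; the fourth stands alone.
# Assumed convention: the closing request of each shared-prefix group carries the group
# size, other slots are 0, and singleton groups get a shared length of 0.
b_seq_len = torch.tensor([520, 523, 517, 300], dtype=torch.int32)
b_shared_seq_len = torch.tensor([512, 512, 512, 0], dtype=torch.int32)
b_mark_shared_group = torch.tensor([0, 0, 3, 1], dtype=torch.int32)

# A diverse-aware decode kernel can then attend over the 512 shared tokens once per group
# and only handle the short per-request suffixes separately.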

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 36 additions & 0 deletions

@@ -111,6 +111,14 @@ def _bind_attention(self):
                 LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self
             )
+        elif "ppl_int8kv_flashdecoding_diverse" in self.mode:
+            self._token_attention_kernel = partial(
+                LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding_diverse, self
+            )
+            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_ppl_int8kv, self)
+            self._context_attention_kernel = partial(
+                LlamaTransformerLayerInfer._context_attention_kernel_ppl_int8kv, self
+            )
         elif "ppl_int8kv_flashdecoding" in self.mode:
             self._token_attention_kernel = partial(
                 LlamaTransformerLayerInfer._token_decode_attention_ppl_int8kv_flashdecoding, self
             )
@@ -784,6 +792,34 @@ def _token_decode_attention_ppl_int8kv_flashdecoding(
             alloc_tensor_func=self.alloc_tensor,
         )
 
+    def _token_decode_attention_ppl_int8kv_flashdecoding_diverse(
+        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
+    ):
+        from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse import (
+            token_decode_attention_flash_decoding,
+        )
+
+        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_k_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
+        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        cache_v_scale = infer_state.mem_manager.scale_buffer[self.layer_num_][
+            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
+        ]
+        return token_decode_attention_flash_decoding(
+            q,
+            infer_state,
+            self.tp_q_head_num_,
+            self.head_dim_,
+            cache_k,
+            cache_k_scale,
+            cache_v,
+            cache_v_scale,
+            out=out,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
     def _token_decode_attention_ppl_int4kv_flashdecoding(
         self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
     ):
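The new `_token_decode_attention_ppl_int8kv_flashdecoding_diverse` method only builds views into the per-layer int8 KV pool before delegating to the Triton driver. A small standalone sketch of that slicing follows; the shapes are assumptions for illustration (in lightllm they come from the mem_manager), and the one-scale-per-8-values layout is inferred from the "group8" op name rather than stated in the diff.

import torch

# Assumed toy shapes: kv_buffer packs the K heads followed by the V heads along dim 1,
# and scale_buffer holds per-group dequantization scales with the same head layout.
token_num, tp_k_head_num, tp_v_head_num, head_dim = 2048, 8, 8, 128
kv_buffer = torch.empty(token_num, tp_k_head_num + tp_v_head_num, head_dim, dtype=torch.int8)
scale_buffer = torch.empty(token_num, tp_k_head_num + tp_v_head_num, head_dim // 8, dtype=torch.float16)

cache_k = kv_buffer[:, 0:tp_k_head_num, :]
cache_k_scale = scale_buffer[:, 0:tp_k_head_num, :]
cache_v = kv_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :]
cache_v_scale = scale_buffer[:, tp_k_head_num : tp_k_head_num + tp_v_head_num, :]

# All four are zero-copy views into the shared pool, so building them per layer and per
# decode step is cheap; the actual dequantization happens inside the attention kernels.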
lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse.py

Lines changed: 63 additions & 19 deletions

@@ -1,11 +1,15 @@
 # int8kv flash decoding attention implementation tailored for diverse mode; it enables more efficient diverse sampling
 import torch
 from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops
+from lightllm.common.basemodel.infer_struct import InferStateInfo
+from .ppl_int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1
+from .ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3
+from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size
 
 
 def token_decode_attention_flash_decoding(
     q,
-    infer_state,
+    infer_state: InferStateInfo,
     q_head_num,
     head_dim,
     cache_k,
@@ -14,14 +18,21 @@ def token_decode_attention_flash_decoding(
     cache_v_scale,
     out=None,
     alloc_tensor_func=torch.empty,
+    shared_streams_dict={},
 ):
+    if "stream1" not in shared_streams_dict:
+        shared_streams_dict["stream1"] = torch.cuda.Stream()
+    if "stream2" not in shared_streams_dict:
+        shared_streams_dict["stream2"] = torch.cuda.Stream()
+
+    stream1 = shared_streams_dict["stream1"]
+    stream2 = shared_streams_dict["stream2"]
+
     BLOCK_SEQ = 256
     batch_size = infer_state.batch_size
     max_len_in_batch = infer_state.max_len_in_batch
     calcu_shape1 = (batch_size, q_head_num, head_dim)
 
-    from .flash_decoding_stage2 import flash_decode_stage2
-
     o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out
 
     mid_o = alloc_tensor_func(
@@ -31,21 +42,54 @@
         [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 2], dtype=q.dtype, device="cuda"
     )
 
-    light_ops.group8_int8kv_flashdecoding_stage1(
-        BLOCK_SEQ,
-        mid_o,
-        mid_o_logexpsum,
-        1.0 / (head_dim ** 0.5),
-        q.view(calcu_shape1),
-        cache_k,
-        cache_k_scale,
-        cache_v,
-        cache_v_scale,
-        infer_state.req_manager.req_to_token_indexs,
-        infer_state.b_req_idx,
-        infer_state.b_seq_len,
-        infer_state.max_len_in_batch,
-    )
+    current_stream = torch.cuda.current_stream()
 
-    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
+    stream1.wait_stream(current_stream)
+    with torch.cuda.stream(stream1):
+        flash_decode_stage1(
+            q=q.view(calcu_shape1),
+            k=cache_k,
+            k_scale=cache_k_scale,
+            v=cache_v,
+            v_scale=cache_v_scale,
+            Req_to_tokens=infer_state.req_manager.req_to_token_indexs,
+            B_req_idx=infer_state.b_req_idx,
+            b_shared_seq_len=infer_state.b_shared_seq_len,
+            b_mark_shared_group=infer_state.b_mark_shared_group,
+            b_seq_len=infer_state.b_seq_len,
+            max_len_in_batch=infer_state.max_len_in_batch,
+            mid_out=mid_o,
+            mid_out_logsumexp=mid_o_logexpsum,
+            BLOCK_SEQ=BLOCK_SEQ,
+            max_batch_group_size=get_diverse_max_batch_shared_group_size(),
+        )
+    stream2.wait_stream(current_stream)
+    with torch.cuda.stream(stream2):
+        light_ops.group8_int8kv_flashdecoding_stage1(
+            BLOCK_SEQ,
+            mid_o,
+            mid_o_logexpsum,
+            1.0 / (head_dim ** 0.5),
+            q.view(calcu_shape1),
+            cache_k,
+            cache_k_scale,
+            cache_v,
+            cache_v_scale,
+            infer_state.req_manager.req_to_token_indexs,
+            infer_state.b_req_idx,
+            infer_state.b_seq_len,
+            infer_state.max_len_in_batch,
+        )
+
+    current_stream.wait_stream(stream1)
+    current_stream.wait_stream(stream2)
+
+    flash_diverse_decode_stage3(
+        mid_out=mid_o,
+        mid_out_logexpsum=mid_o_logexpsum,
+        B_Seqlen=infer_state.b_seq_len,
+        b_shared_seq_len=infer_state.b_shared_seq_len,
+        O=o_tensor.view(calcu_shape1),
+        block_seq=BLOCK_SEQ,
+    )
     return o_tensor
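The driver overlaps the two stage-1 kernels (the new Triton kernel for the shared-prefix blocks and the existing light_ops kernel) on two side CUDA streams, then joins both before the stage-3 merge. Below is a minimal, self-contained sketch of that pattern; `kernel_a` and `kernel_b` are placeholder callables, not the lightllm API.

import torch

def run_two_kernels_overlapped(kernel_a, kernel_b, _streams={}):
    """Issue two independent kernels on side streams and join on the caller's stream.

    The mutable default argument caches the streams per process, mirroring the
    shared_streams_dict trick in the diff above.
    """
    if "s1" not in _streams:
        _streams["s1"] = torch.cuda.Stream()
        _streams["s2"] = torch.cuda.Stream()
    s1, s2 = _streams["s1"], _streams["s2"]
    current = torch.cuda.current_stream()

    # Each side stream first waits on the current stream so it sees prior work
    # (e.g. the projection that produced q and the mid buffers).
    s1.wait_stream(current)
    with torch.cuda.stream(s1):
        kernel_a()
    s2.wait_stream(current)
    with torch.cuda.stream(s2):
        kernel_b()

    # The caller's stream must wait on both before launching the consumer
    # (stage 3 in the diff), otherwise it could read partial results.
    current.wait_stream(s1)
    current.wait_stream(s2)

This only pays off if the two stage-1 kernels write disjoint slots of mid_o and mid_o_logexpsum, which is presumably how the shared and non-shared blocks are laid out here.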
lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage3.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_flash_diverse_decode_stage3(
+    B_Seqlen,
+    b_shared_seq_len,
+    Mid_O,  # [batch, head, seq_block_num, head_dim]
+    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
+    O,  # [batch, head, head_dim]
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_mid_od,
+    stride_mid_o_eb,
+    stride_mid_o_eh,
+    stride_mid_o_es,
+    stride_obs,
+    stride_oh,
+    stride_od,
+    BLOCK_SEQ: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_head = tl.program_id(1)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch)
+
+    shared_block_n = tl.cdiv(cur_batch_shared_len, BLOCK_SEQ)
+    not_shared_block_n = tl.cdiv(cur_batch_seq_len - cur_batch_shared_len, BLOCK_SEQ)
+
+    block_n_size = shared_block_n + not_shared_block_n
+
+    sum_exp = 0.0
+    max_logic = -float("inf")
+    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
+
+    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
+    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
+    for block_seq_n in range(0, block_n_size, 1):
+        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
+        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
+        new_max_logic = tl.maximum(tlogic, max_logic)
+
+        old_scale = tl.exp(max_logic - new_max_logic)
+        acc *= old_scale
+        exp_logic = tl.exp(tlogic - new_max_logic)
+        acc += exp_logic * tv
+        sum_exp = sum_exp * old_scale + exp_logic
+        max_logic = new_max_logic
+
+    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
+    return
+
+
+@torch.no_grad()
+def flash_diverse_decode_stage3(
+    mid_out: torch.Tensor,
+    mid_out_logexpsum: torch.Tensor,
+    B_Seqlen: torch.Tensor,
+    b_shared_seq_len: torch.Tensor,
+    O: torch.Tensor,
+    block_seq: int,
+):
+    Lk = mid_out.shape[-1]
+    assert Lk in {16, 32, 64, 128}
+    batch, head_num = mid_out.shape[0], mid_out.shape[1]
+    grid = (batch, head_num)
+
+    _fwd_kernel_flash_diverse_decode_stage3[grid](
+        B_Seqlen=B_Seqlen,
+        b_shared_seq_len=b_shared_seq_len,
+        Mid_O=mid_out,
+        Mid_O_LogExpSum=mid_out_logexpsum,
+        O=O,
+        stride_mid_ob=mid_out.stride(0),
+        stride_mid_oh=mid_out.stride(1),
+        stride_mid_os=mid_out.stride(2),
+        stride_mid_od=mid_out.stride(3),
+        stride_mid_o_eb=mid_out_logexpsum.stride(0),
+        stride_mid_o_eh=mid_out_logexpsum.stride(1),
+        stride_mid_o_es=mid_out_logexpsum.stride(2),
+        stride_obs=O.stride(0),
+        stride_oh=O.stride(1),
+        stride_od=O.stride(2),
+        BLOCK_SEQ=block_seq,
+        BLOCK_DMODEL=Lk,
+        num_warps=4,
+        num_stages=2,
+    )
+    return
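Per (batch, head) program, the kernel folds the per-block partial outputs together with a running, numerically stable softmax over the block-wise log-sum-exp values, visiting the shared-prefix blocks and the per-request blocks as one contiguous range. A rough PyTorch reference of the same reduction, included here only as a sketch for understanding (not code from the repository):

import math
import torch

def stage3_reference(mid_out, mid_out_logexpsum, b_seq_len, b_shared_seq_len, block_seq):
    # mid_out: [batch, head, block_slots, head_dim]; mid_out_logexpsum: [batch, head, block_slots]
    batch, head_num, _, head_dim = mid_out.shape
    out = torch.empty(batch, head_num, head_dim, dtype=torch.float32, device=mid_out.device)
    for b in range(batch):
        shared = int(b_shared_seq_len[b])
        total = int(b_seq_len[b])
        # Shared-prefix blocks and per-request blocks are counted separately,
        # matching shared_block_n + not_shared_block_n in the kernel.
        n_blocks = math.ceil(shared / block_seq) + math.ceil((total - shared) / block_seq)
        logits = mid_out_logexpsum[b, :, :n_blocks].float()   # [head, n_blocks]
        values = mid_out[b, :, :n_blocks, :].float()          # [head, n_blocks, head_dim]
        # softmax over blocks equals exp(l_i - max) / sum_j exp(l_j - max): the same
        # normalization the kernel maintains incrementally via max_logic / sum_exp.
        weights = torch.softmax(logits, dim=-1)
        out[b] = torch.einsum("hn,hnd->hd", weights, values)
    return out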

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 7 additions & 1 deletion

@@ -144,7 +144,6 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         # Each non-zero entry of b_mark_shared_group records how many preceding requests form a shared-prefix
         # group together with it. Requests in the same shared-prefix group necessarily carry identical values
         # in the corresponding entries of b_shared_seq_len. Some modes can use these two inputs to speed up kernels.
-        b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
         b_mark_shared_group = []
         shared_nodes = [req.shared_kv_node for req in run_reqs]
         _current_group = []
@@ -169,6 +168,13 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
             _current_group.clear()
 
         assert len(b_mark_shared_group) == len(run_reqs)
+        # If a shared group contains only one request, force its shared length to 0 so the
+        # kernel skips useless work and runs more efficiently.
+        b_shared_seq_len = [
+            0 if group_size == 1 else shared_len
+            for shared_len, group_size in zip(b_shared_seq_len, b_mark_shared_group)
+        ]
+        b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
         b_mark_shared_group = torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
     else:
         b_shared_seq_len = None
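A quick worked example of the singleton fix-up above (illustrative values, assuming the group-mark convention from the comment in this file):

# Three requests share a prefix; the fourth forms a group of size 1.
b_shared_seq_len = [128, 128, 128, 96]
b_mark_shared_group = [0, 0, 3, 1]

b_shared_seq_len = [
    0 if group_size == 1 else shared_len
    for shared_len, group_size in zip(b_shared_seq_len, b_mark_shared_group)
]
print(b_shared_seq_len)  # [128, 128, 128, 0] -- the lone request no longer reports a shared prefix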
Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import pytest
+import torch
+from lightllm.models.llama.triton_kernel.ppl_int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3
+
+
+@pytest.mark.parametrize(
+    "batch, head_num, seq_len, shared_seq_len, block_seq, head_dim",
+    [
+        (2, 4, 256, 256, 256, 128),
+        (1, 8, 256 * 2, 256, 256, 128),
+        (3, 2, 256 * 4, 256 * 2, 256, 128),
+    ],
+)
+def test_flash_diverse_decode_stage3(batch, head_num, seq_len, shared_seq_len, block_seq, head_dim):
+    # Initialize inputs
+    mid_out = torch.randn(batch, head_num, seq_len // block_seq + 2, head_dim, dtype=torch.bfloat16, device="cuda")
+    mid_out_logexpsum = torch.randn(batch, head_num, seq_len // block_seq + 2, dtype=torch.bfloat16, device="cuda")
+    B_Seqlen = torch.tensor([seq_len] * batch, dtype=torch.int32, device="cuda")
+    b_shared_seq_len = torch.tensor([shared_seq_len] * batch, dtype=torch.int32, device="cuda")
+    out = torch.zeros(batch, head_num, head_dim, dtype=torch.float32, device="cuda")
+
+    # Call the function
+    flash_diverse_decode_stage3(mid_out, mid_out_logexpsum, B_Seqlen, b_shared_seq_len, out, block_seq)
+
+    true_out = torch.zeros_like(out)
+    from lightllm.models.llama.triton_kernel.flash_decoding_stage2 import flash_decode_stage2
+
+    flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, true_out, block_seq)
+
+    assert torch.allclose(out, true_out, atol=1e-2)
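The plain flash_decode_stage2 appears to be a valid reference here because every parametrization uses seq_len and shared_seq_len that are multiples of block_seq, so the diverse kernel reduces over cdiv(shared, block_seq) + cdiv(seq_len - shared, block_seq) = cdiv(seq_len, block_seq) mid blocks, the same set stage 2 visits; the two outputs should then agree up to the bfloat16 tolerance.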
