Commit dafc5ed

Author: niushengxiao (committed)
feat: add deepseekv2_bf16kv and deepseekv2_fp8kv modes
1 parent c757cf5 commit dafc5ed

File tree: 9 files changed, +112 −7 lines
lightllm/common/deepseek2_fp8kv_mem_manager.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+
+
+class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        # the scale is appended to the end of kv_buffer, so add 2 to head_dim; the dtype is unified to uint8
+        super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)
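
Note on the layout this manager assumes (a minimal sketch, not lightllm code): each kv_buffer row holds head_dim bytes of float8_e4m3fn payload followed by 2 bytes that are reinterpreted as one bfloat16 per-token scale, which is why the constructor passes head_dim + 2 and torch.uint8 to the parent. The head_dim value below is illustrative (DeepSeek-V2 uses kv_lora_rank 512 + qk_rope_head_dim 64).

import torch

# Illustrative layout: `head_dim` fp8 bytes per token, then 2 bytes holding a bf16 scale.
head_dim, tokens = 576, 4
row = torch.zeros(tokens, 1, head_dim + 2, dtype=torch.uint8)

fp8_payload = row[:, :, :-2].view(torch.float8_e4m3fn)   # quantized nope + rope parts
bf16_scale = row[:, :, -2:].view(torch.bfloat16)         # one bf16 scale per token

print(fp8_payload.shape, bf16_scale.shape)               # (4, 1, 576), (4, 1, 1)

The same two slices appear in _decompress_kv and _token_gqa_decode_attention_flashdecoding in this commit.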

lightllm/common/mem_utils.py

Lines changed: 9 additions & 0 deletions
@@ -3,6 +3,9 @@
 from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from lightllm.common.ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
+

 logger = init_logger(__name__)

@@ -18,6 +21,12 @@ def select_mem_manager_class(mode):
     elif "triton_int8kv" in mode:
         memory_manager_class = INT8KVMemoryManager
         logger.info("Model kv cache using mode triton int8kv")
+    elif "deepseek2_bf16kv" in mode:
+        memory_manager_class = Deepseek2MemoryManager
+        logger.info("Model kv cache using mode deepseek2 bf16kv")
+    elif "deepseek2_fp8kv" in mode:
+        memory_manager_class = Deepseek2FP8KVMemoryManager
+        logger.info("Model kv cache using mode deepseek2 fp8kv")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 40 additions & 3 deletions
@@ -22,6 +22,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -68,6 +69,10 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         return

     def _bind_attention(self):
+        if "deepseek2_bf16kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
         if self.enable_cc_method:
             self._context_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
@@ -79,7 +84,6 @@ def _bind_attention(self):
             self._token_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
             )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
         if self.is_moe:
             if self.enable_dp:
                 if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -135,7 +139,15 @@ def _get_o(

     def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
         if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if "deepseek2_bf16kv" in self.mode:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+
             compressed_kv = self.alloc_tensor(
                 [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
             )
@@ -147,7 +159,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                 infer_state.b_req_idx,
                 infer_state.b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
             )
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
         else:
             compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                 kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
@@ -264,12 +281,18 @@ def _token_gqa_decode_attention_flashdecoding(
             )
             return o_tensor
         else:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if "deepseek2_bf16kv" in self.mode:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
             return gqa_token_decode_attention_flash_decoding(
                 q_nope,
                 q_rope,
                 kv[:, :, : -self.qk_rope_head_dim],
                 kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
                 infer_state,
                 self.tp_q_head_num_,
                 self.kv_lora_rank,
@@ -321,6 +344,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         )
         return

+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
     def _ffn_dp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
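
The fp8 path quantizes the compressed KV per token through lightllm's vLLMFP8w8a8QuantizationMethod and stores the bf16 scale next to the fp8 bytes; _decompress_kv later multiplies the loaded values back by k_scale. The snippet below is a hand-rolled illustration of that per-token symmetric fp8 round trip in plain PyTorch, not the actual quantize_scaled_mm_fp8 implementation:

import torch

def fp8_quant_per_token(x: torch.Tensor):
    # x: (tokens, dim). One scale per token so values fit in e4m3fn's range (max finite ~448).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = (amax / 448.0).to(torch.bfloat16)
    q = (x / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return q, scale

def fp8_dequant_per_token(q, scale):
    # Mirrors the dequant in _decompress_kv: cast to the scale dtype, then multiply.
    return q.to(scale.dtype) * scale

x = torch.randn(4, 576, dtype=torch.bfloat16)
q, s = fp8_quant_per_token(x)
x_hat = fp8_dequant_per_token(q, s)
print((x - x_hat).abs().max())   # small quantization error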

lightllm/models/deepseek2/model.py

Lines changed: 6 additions & 1 deletion
@@ -7,7 +7,9 @@

 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.mem_utils import select_mem_manager_class


 logger = init_logger(__name__)
@@ -48,14 +50,17 @@ def _verify_params(self):
         return super()._verify_params()

     def _init_mem_manager(self):
-        self.mem_manager = Deepseek2MemoryManager(
+        self.mem_manager = select_mem_manager_class(self.mode)(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=1,
             head_dim=self.config["kv_lora_rank"] + self.config["qk_rope_head_dim"],
             layer_num=self.config["num_hidden_layers"],
             mem_fraction=self.mem_fraction,
         )
+        assert isinstance(self.mem_manager, Deepseek2MemoryManager) or isinstance(
+            self.mem_manager, Deepseek2FP8KVMemoryManager
+        )
         return

     def _init_weights(self):

lightllm/models/deepseek2/triton_kernel/destindex_copy_kv.py

Lines changed: 14 additions & 1 deletion
@@ -12,9 +12,11 @@ def _is_power_of_two(n):
 def _fwd_kernel_destindex_copy_kv(
     KV_nope,
     KV_rope,
+    KV_scale,
     Dest_loc,
     O_nope,
     O_rope,
+    O_scale,
     stride_kv_nope_bs,
     stride_kv_nope_h,
     stride_kv_nope_d,
@@ -29,6 +31,7 @@
     stride_o_rope_d,
     kv_nope_head_num,
     kv_rope_head_num,
+    HAS_SCALE: tl.constexpr,
     BLOCK_DMODEL_NOPE: tl.constexpr,
     BLOCK_DMODEL_ROPE: tl.constexpr,
 ):
@@ -47,13 +50,20 @@
     kv_nope = tl.load(kv_nope_ptrs)
     kv_rope = tl.load(kv_rope_ptrs)

+    if HAS_SCALE:
+        offs_d_scale = tl.arange(0, 2)
+        o_scale_ptrs = O_scale + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_scale[None, :]
+        kv_scale_ptrs = KV_scale + cur_index * 2 + offs_d_scale[None, :]
+        kv_scale = tl.load(kv_scale_ptrs)
+        tl.store(o_scale_ptrs, kv_scale)
+
     tl.store(o_nope_ptrs, kv_nope)
     tl.store(o_rope_ptrs, kv_rope)
     return


 @torch.no_grad()
-def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope):
+def destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope, O_scale=None, KV_scale=None):
     seq_len = DestLoc.shape[0]
     kv_nope_head_num = KV_nope.shape[1]
     kv_rope_head_num = KV_rope.shape[1]
@@ -71,9 +81,11 @@
     _fwd_kernel_destindex_copy_kv[grid](
         KV_nope,
         KV_rope,
+        KV_scale,
         DestLoc,
         O_nope,
         O_rope,
+        O_scale,
         KV_nope.stride(0),
         KV_nope.stride(1),
         KV_nope.stride(2),
@@ -88,6 +100,7 @@
         O_rope.stride(2),
         kv_nope_head_num,
         kv_rope_head_num,
+        HAS_SCALE=KV_scale is not None,
         BLOCK_DMODEL_NOPE=kv_nope_head_dim,
         BLOCK_DMODEL_ROPE=kv_rope_head_dim,
         num_warps=num_warps,
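
For readers less familiar with the Triton kernel, the added HAS_SCALE path is a scatter of the 2-byte per-token scale next to the nope/rope bytes. A hedged PyTorch reference of the semantics (shapes assumed from the FP8 memory-manager layout; the real kernel does this in one fused launch):

import torch

def destindex_copy_kv_reference(kv_nope, kv_rope, dest_loc, o_nope, o_rope, o_scale=None, kv_scale=None):
    """Reference only. Scatter per-token nope/rope bytes (and, if present, the 2-byte
    per-token scale) into the kv_buffer rows selected by dest_loc.
    Assumed shapes: kv_nope/kv_rope (tokens, 1, d), kv_scale (tokens, 1, 2), all uint8."""
    o_nope[dest_loc] = kv_nope
    o_rope[dest_loc] = kv_rope
    if kv_scale is not None:
        o_scale[dest_loc] = kv_scale
    return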

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ def gqa_token_decode_attention_flash_decoding(
     q_rope,
     kv_nope,
     kv_rope,
+    kv_scale,
     infer_state,
     q_head_num,
     kv_lora_rank,
@@ -63,6 +64,7 @@
         q_rope.view(calcu_shape2),
         kv_nope,
         kv_rope,
+        kv_scale,
         infer_state.req_manager.req_to_token_indexs,
         infer_state.b_req_idx,
         infer_state.b_seq_len,
@@ -111,6 +113,7 @@
         q_rope.view(calcu_shape2),
         kv_nope,
         kv_rope,
+        kv_scale,
         infer_state.req_manager.req_to_token_indexs,
         infer_state.b_req_idx,
         infer_state.b_seq_len,

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py

Lines changed: 16 additions & 1 deletion
@@ -10,6 +10,7 @@ def _fwd_kernel_flash_decode_stage1_padding(
    Q_rope,
    KV_nope,
    KV_rope,
+    KV_scale,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
@@ -35,11 +36,13 @@
    stride_mid_od,
    stride_mid_o_eh,
    stride_mid_o_es,
+    stride_kv_scaled_bs,
    block_size_ptr,
    num_sm,
    head_group_num,
    head_num,
    batch_size,
+    HAS_SCALE: tl.constexpr,
    Q_HEAD_NUM: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_ROPE_DMODEL: tl.constexpr,
@@ -108,9 +111,14 @@
        )
        off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None]
        kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0)
-        att_value = tl.dot(q, kv)
        off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None]
        rope_kv = tl.load(KV_rope + off_rope_kv, mask=seq_n_mask[None, :], other=0.0)
+        if HAS_SCALE:
+            off_kv_scale = kv_loc[None, :] * stride_kv_scaled_bs
+            kv_scale = tl.load(KV_scale + off_kv_scale, mask=seq_n_mask[None, :], other=0.0)
+            kv = (kv * kv_scale).to(kv_scale.dtype)
+            rope_kv = (rope_kv * kv_scale).to(kv_scale.dtype)
+        att_value = tl.dot(q, kv)
        att_value += tl.dot(q_rope, rope_kv)

        att_value *= sm_scale
@@ -167,6 +175,7 @@ def flash_decode_stage1(
    q_rope,
    kv_nope,
    kv_rope,
+    kv_scale,
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
@@ -201,6 +210,7 @@
        q_rope,
        kv_nope,
        kv_rope,
+        kv_scale,
        softmax_scale,
        Req_to_tokens,
        B_req_idx,
@@ -214,11 +224,13 @@
        *kv_rope.stride(),
        *mid_out.stride(),
        *mid_out_logsumexp.stride(),
+        kv_scale.stride(0) if kv_scale is not None else 0,
        in_block_seq,
        num_sm=1,
        head_group_num=head_group_num,
        head_num=q_head_num,
        batch_size=batch_size,
+        HAS_SCALE=1 if kv_scale is not None else 0,
        Q_HEAD_NUM=Q_HEAD_NUM,
        BLOCK_DMODEL=q_nope_dim,
        BLOCK_ROPE_DMODEL=q_rope_dim,
@@ -243,6 +255,7 @@
        q_rope,
        kv_nope,
        kv_rope,
+        kv_scale,
        softmax_scale,
        Req_to_tokens,
        B_req_idx,
@@ -256,11 +269,13 @@
        *kv_rope.stride(),
        *mid_out.stride(),
        *mid_out_logsumexp.stride(),
+        kv_scale.stride(0) if kv_scale is not None else 0,
        in_block_seq,
        num_sm=num_sm,
        head_group_num=head_group_num,
        head_num=q_head_num,
        batch_size=batch_size,
+        HAS_SCALE=1 if kv_scale is not None else 0,
        Q_HEAD_NUM=Q_HEAD_NUM,
        BLOCK_DMODEL=q_nope_dim,
        BLOCK_ROPE_DMODEL=q_rope_dim,
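
Conceptually, the HAS_SCALE branch dequantizes each KV column by its per-token scale in registers before the dot products, so the attention math itself stays in bf16. A hedged PyTorch sketch of one tile's score computation (shapes are illustrative, not the kernel's exact tiling):

import torch

heads, d, block_n = 16, 512, 64
q = torch.randn(heads, d, dtype=torch.bfloat16)
kv_fp8 = torch.randn(d, block_n, dtype=torch.bfloat16).to(torch.float8_e4m3fn)
kv_scale = torch.rand(block_n, dtype=torch.bfloat16)      # one scale per cached token

# What the added branch does: dequantize each KV column by its token scale, then dot.
kv = kv_fp8.to(torch.bfloat16) * kv_scale[None, :]
att_value = q @ kv                                        # plays the role of tl.dot(q, kv)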

lightllm/models/deepseek2/triton_kernel/sample_kv.py

Lines changed: 15 additions & 0 deletions
@@ -10,16 +10,20 @@
 @triton.jit
 def _sample_kv_kernel(
     KV_input,
+    KV_scale,
     KV_nope,
     KV_rope,
+    K_scale,
     B_start_loc,
     B_Seqlen,
     Req_to_tokens,
     B_req_idx,
     stride_input_dim,
+    stride_scale_dim,
     stride_nope_dim,
     stride_rope_dim,
     stride_req_to_tokens_b,
+    HAS_SCALE: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_ROPE_DMODEL: tl.constexpr,
@@ -52,6 +56,11 @@ def _sample_kv_kernel(
     rope_ptrs = KV_rope + off_rope
     tl.store(nope_ptrs, kv_nope, mask=offs_m[:, None] < block_end_loc)
     tl.store(rope_ptrs, kv_rope, mask=offs_m[:, None] < block_end_loc)
+    if HAS_SCALE:
+        kv_scale = tl.load(KV_scale + kv_loc * stride_scale_dim, mask=offs_m < block_end_loc)
+        off_k_scale = cur_batch_start_loc + offs_m
+        k_scale_ptrs = K_scale + off_k_scale
+        tl.store(k_scale_ptrs, kv_scale, mask=offs_m < block_end_loc)
     return


@@ -63,6 +72,8 @@ def sample_kv(
     b_req_idx,
     b_seq_len,
     req_to_token_indexs,
+    kv_scale=None,
+    k_scale=None,
 ):
     BLOCK = 128 if not TESLA else 64

@@ -85,16 +96,20 @@
     b_start_loc = torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[1:].cumsum(0)])
     _sample_kv_kernel[grid](
         kv_input,
+        kv_scale,
         kv_nope,
         kv_rope,
+        k_scale,
         b_start_loc,
         b_seq_len,
         req_to_token_indexs,
         b_req_idx,
         kv_input.stride(0),
+        kv_scale.stride(0) if kv_scale is not None else 0,
         kv_nope.stride(0),
         kv_rope.stride(0),
         req_to_token_indexs.stride(0),
+        HAS_SCALE=kv_scale is not None,
         BLOCK_M=BLOCK,
         BLOCK_DMODEL=nope_dim,
         BLOCK_ROPE_DMODEL=rope_dim,
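
With HAS_SCALE, sample_kv also gathers each cached token's scale into the contiguous k_scale tensor that _decompress_kv allocates, so dequantization can happen immediately after sampling. A hedged PyTorch reference of the added path (index bookkeeping simplified to 1-D views, which is an assumption of this sketch):

import torch

def gather_k_scale_reference(kv_scale, req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, k_scale):
    """Reference for the HAS_SCALE path. Assumed 1-D views: kv_scale holds one scale per
    cache slot, k_scale is the packed per-token output for the whole batch."""
    for i in range(b_req_idx.numel()):
        seq_len = int(b_seq_len[i])
        start = int(b_start_loc[i])
        token_idxs = req_to_token_indexs[b_req_idx[i], :seq_len]   # cache slots of this request
        k_scale[start:start + seq_len] = kv_scale[token_idxs]
    return k_scale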
