
Commit 65e2747

feat: add deepseekv2_bf16kv and deepseekv2_fp8kv modes (#712)
1 parent a740b7f commit 65e2747

File tree

10 files changed, +1040 -19 lines changed

lightllm/common/deepseek2_fp8kv_mem_manager.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+
+
+class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        # the scale is appended to the end of kv_buffer, so head_dim grows by 2; dtype is switched to uint8
+        super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)
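Reading of the new manager (not part of the commit text): each cache slot stays a flat uint8 row in which the first head_dim bytes hold the fp8-quantized payload and the trailing 2 bytes hold one bfloat16 dequantization scale, which is why head_dim grows by 2 and the dtype becomes uint8. A minimal sketch of that layout; the 576 payload width (kv_lora_rank 512 + qk_rope_head_dim 64) is an assumed illustration and needs a PyTorch build with float8 support:

import torch

tokens, head_num, payload_dim = 4, 1, 576          # 576 = 512 (kv_lora_rank) + 64 (qk_rope_head_dim), assumed
buf = torch.zeros(tokens, head_num, payload_dim + 2, dtype=torch.uint8)

fp8_part = buf[:, :, :-2].view(torch.float8_e4m3fn)   # quantized kv payload, same byte width as uint8
scale_part = buf[:, :, -2:].view(torch.bfloat16)      # one bf16 dequant scale per token per head

print(fp8_part.shape, scale_part.shape)  # torch.Size([4, 1, 576]) torch.Size([4, 1, 1])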

lightllm/common/mem_utils.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ def select_mem_manager_class(mode):
     elif "triton_int8kv" in mode:
         memory_manager_class = INT8KVMemoryManager
         logger.info("Model kv cache using mode triton int8kv")
+    elif "triton_fp8kv" in mode:
+        raise Exception("currently only for deepseek")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")
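A hedged usage sketch (assumes lightllm is importable and that mode is the usual list of mode strings): the generic selector refuses triton_fp8kv, so the fp8 KV path is only reachable through the DeepSeek-V2 model, which overrides _init_mem_manager (see lightllm/models/deepseek2/model.py below).

from lightllm.common.mem_utils import select_mem_manager_class

try:
    select_mem_manager_class(["triton_fp8kv"])     # mode is assumed to be a list of mode strings
except Exception as e:
    print(e)                                        # currently only for deepseek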

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 164 additions & 16 deletions
@@ -10,10 +10,12 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
 from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv

 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
+from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
@@ -22,6 +24,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,19 +70,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return

-    def _bind_attention(self):
-        if self.enable_cc_method:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-            )
-        else:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
-            )
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+    def _bind_func(self):
+        super()._bind_func()
+        self._bind_ffn()
+        return
+
+    def _bind_ffn(self):
         if self.is_moe:
             if self.enable_dp:
                 if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -92,6 +88,36 @@ def _bind_attention(self):
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)

+    def _bind_attention(self):
+        if "triton_fp8kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
+            )
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                )
+        else:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+                )
+
     def _get_qkv(
         self,
         input: torch.Tensor,
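The new _bind_attention picks kernels once at construction time by binding unbound methods to self with functools.partial. A toy, self-contained sketch of the same pattern; the class and method names here are invented for illustration:

from functools import partial


class KernelBinder:  # toy stand-in for Deepseek2TransformerLayerInfer
    def __init__(self, mode):
        self.mode = mode
        # Bind the concrete implementation once, exactly like _bind_attention above.
        if "triton_fp8kv" in self.mode:
            self._token_attention_kernel = partial(KernelBinder._decode_fp8, self)
        else:
            self._token_attention_kernel = partial(KernelBinder._decode_bf16, self)

    def _decode_fp8(self, q):
        return f"fp8 decode of {q}"

    def _decode_bf16(self, q):
        return f"bf16 decode of {q}"


print(KernelBinder(["triton_fp8kv"])._token_attention_kernel("q"))  # fp8 decode of q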
@@ -133,9 +159,19 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor

-    def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+    def _decompress_kv(
+        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+    ):
         if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if is_fp8:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+
             compressed_kv = self.alloc_tensor(
                 [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
             )
@@ -147,7 +183,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                 infer_state.b_req_idx,
                 infer_state.b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
             )
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
         else:
             compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                 kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
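In the fp8 branch, sample_kv gathers the quantized latent and rope parts plus one scale per token, and the k_scale guard then dequantizes them back to the scale's dtype. A standalone sketch of that dequantization step with made-up shapes (512 and 64 follow DeepSeek-V2's kv_lora_rank and qk_rope_head_dim):

import torch

total_token_num, kv_lora_rank, qk_rope_head_dim = 4, 512, 64
compressed_kv = torch.randn(total_token_num, 1, kv_lora_rank).to(torch.float8_e4m3fn)  # gathered fp8 latent
k_rope = torch.randn(total_token_num, 1, qk_rope_head_dim).to(torch.float8_e4m3fn)     # gathered fp8 rope part
k_scale = torch.rand(total_token_num, 1, dtype=torch.bfloat16)                         # per-token scale

# Same dequantization as the k_scale branch above: cast up, then rescale.
compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
print(compressed_kv.dtype, compressed_kv.shape)  # torch.bfloat16 torch.Size([4, 1, 512])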
@@ -175,7 +216,33 @@ def _context_attention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight)
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
+    def _context_attention_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(
@@ -235,6 +302,50 @@ def _context_attention_kernel_origin(

         return o_tensor

+    def _context_attention_kernel_origin_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+            context_attention_fwd_fp8(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -277,6 +388,29 @@ def _token_gqa_decode_attention_flashdecoding(
             alloc_tensor_func=self.alloc_tensor,
         )

+    def _token_gqa_decode_attention_flashdecoding_fp8(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+        return gqa_token_decode_attention_flash_decoding_fp8(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            kv_scale,
+            infer_state,
+            self.tp_q_head_num_,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.qk_nope_head_dim,
+            self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
     def _splitfuse_attention_kernel(
         self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
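For orientation, a small shape sketch of how the decode path slices the per-layer cache row into fp8 payload plus bf16 scale and then splits the payload into the compressed latent and rope parts before calling the Triton kernel; sizes are illustrative:

import torch

kv_lora_rank, qk_rope_head_dim = 512, 64                          # DeepSeek-V2 values, for illustration
cache = torch.zeros(8, 1, kv_lora_rank + qk_rope_head_dim + 2, dtype=torch.uint8)  # one layer's kv_buffer

kv = cache[:, :, :-2].view(torch.float8_e4m3fn)                   # fp8 payload
kv_scale = cache[:, :, -2:].view(torch.bfloat16)                  # per-token scale
kv_nope = kv[:, :, :-qk_rope_head_dim]                            # compressed latent part, (8, 1, 512)
kv_rope = kv[:, :, -qk_rope_head_dim:]                            # rope part, (8, 1, 64)
print(kv_nope.shape, kv_rope.shape, kv_scale.shape)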
@@ -319,6 +453,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         )
         return

+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
     def _ffn_dp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
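The quantization itself is delegated to vLLMFP8w8a8QuantizationMethod.quantize_scaled_mm_fp8, whose internals are not shown in this commit; the following is only a rough sketch of a per-token e4m3 quantization of the kind such a helper presumably performs, to show why one bf16 scale per token is enough for the round trip:

import torch


def per_token_fp8_quant(x: torch.Tensor):
    # Scale each token (row) so its max magnitude fits e4m3's representable range (~448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float() / fp8_max
    quant = (x.float() / scale).to(torch.float8_e4m3fn)
    return quant, scale.to(torch.bfloat16)


buffer = torch.randn(4, 576, dtype=torch.bfloat16)     # (tokens, kv_lora_rank + qk_rope_head_dim)
quant, scale = per_token_fp8_quant(buffer)
recon = quant.to(torch.bfloat16) * scale                # dequantization, as in _decompress_kv above
print((recon - buffer).abs().max())                     # small quantization error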

lightllm/models/deepseek2/model.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@

 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger


@@ -48,7 +49,10 @@ def _verify_params(self):
         return super()._verify_params()

     def _init_mem_manager(self):
-        self.mem_manager = Deepseek2MemoryManager(
+        manager_class = Deepseek2MemoryManager
+        if "triton_fp8kv" in self.mode:
+            manager_class = Deepseek2FP8KVMemoryManager
+        self.mem_manager = manager_class(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=1,
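A small selection example mirroring _init_mem_manager; the mode list is hypothetical, and 512 + 64 = 576 is DeepSeek-V2's per-token payload width, which the FP8 manager widens to 578 uint8 bytes:

from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager

mode = ["triton_fp8kv"]                                # hypothetical mode list handed to the model
manager_class = Deepseek2FP8KVMemoryManager if "triton_fp8kv" in mode else Deepseek2MemoryManager
print(manager_class.__name__)                          # Deepseek2FP8KVMemoryManager
print(512 + 64, 512 + 64 + 2)                          # 576 fp8 bytes per token, 578 with the bf16 scale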
