Commit 15f3a5c

niushengxiao committed
feat: add triton_fp8kv mode for deepseek2
1 parent b8cfd70 commit 15f3a5c

File tree

10 files changed: +858, -20 lines changed

lightllm/common/deepseek2_fp8kv_mem_manager.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+import torch
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+
+
+class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        # the scale is appended to the end of kv_buffer, hence head_dim + 2; dtype is unified to uint8
+        super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)

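Note: for orientation, a minimal sketch (not part of the commit) of the buffer layout this manager implies: each token's compressed KV head occupies head_dim fp8 (e4m3) bytes plus a trailing 2-byte bfloat16 scale, all carried in one uint8 buffer. The per-token max-abs scaling below is an illustrative assumption; the commit's actual quantization is done by vLLMFP8w8a8QuantizationMethod.quantize_scaled_mm_fp8.

import torch

head_dim = 576                    # kv_lora_rank (512) + qk_rope_head_dim (64) in DeepSeek-V2
buf = torch.zeros((4, 1, head_dim + 2), dtype=torch.uint8)   # what this manager allocates per layer

x = torch.randn(4, 1, head_dim).to(torch.bfloat16)           # stand-in for the compressed KV
scale = x.abs().amax(dim=-1, keepdim=True) / 448.0           # 448 = max finite float8_e4m3fn value
q = (x / scale).to(torch.float8_e4m3fn)

buf[:, :, :-2] = q.view(torch.uint8)              # fp8 payload stored as raw bytes
buf[:, :, -2:] = scale.view(torch.uint8)          # one bfloat16 scale as the last 2 bytes

# reading back mirrors what the fp8 attention paths in this commit do
kv = buf[:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = buf[:, :, -2:].view(torch.bfloat16)
x_hat = kv.to(torch.bfloat16) * kv_scale          # dequantized approximation of x
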
lightllm/common/mem_utils.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ def select_mem_manager_class(mode):
     elif "triton_int8kv" in mode:
         memory_manager_class = INT8KVMemoryManager
         logger.info("Model kv cache using mode triton int8kv")
+    elif "triton_fp8kv" in mode:
+        raise Exception("currently only for deepseek")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 167 additions & 17 deletions
@@ -8,12 +8,16 @@
 from lightllm.models.deepseek2.triton_kernel.destindex_copy_kv import destindex_copy_kv
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad import (
     context_attention_fwd,
+    context_attention_fwd_fp8,
     context_attention_fwd_no_prompt_cache,
 )
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
 from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv

-from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
+from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import (
+    gqa_token_decode_attention_flash_decoding,
+    gqa_token_decode_attention_flash_decoding_fp8,
+)
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
@@ -22,6 +26,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,19 +72,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return

-    def _bind_attention(self):
-        if self.enable_cc_method:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-            )
-        else:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
-            )
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+    def _bind_func(self):
+        super()._bind_func()
+        self._bind_ffn()
+        return
+
+    def _bind_ffn(self):
         if self.is_moe:
             if self.enable_dp:
                 if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -92,6 +90,36 @@ def _bind_attention(self):
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)

+    def _bind_attention(self):
+        if "triton_fp8kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
+            )
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                )
+        else:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+                )
+
     def _get_qkv(
         self,
         input: torch.Tensor,
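
Note: the kernel selection above follows lightllm's usual pattern of binding unbound methods onto the instance with functools.partial at init time. A standalone toy sketch (class and method names here are invented for illustration):

from functools import partial

class AttnBinderDemo:
    def __init__(self, mode):
        self.mode = mode
        if "triton_fp8kv" in self.mode:
            self._token_attention_kernel = partial(AttnBinderDemo._decode_fp8, self)
        else:
            self._token_attention_kernel = partial(AttnBinderDemo._decode_normal, self)

    def _decode_fp8(self):
        return "fp8 decode path"

    def _decode_normal(self):
        return "normal decode path"

print(AttnBinderDemo(["triton_fp8kv"])._token_attention_kernel())   # fp8 decode path
print(AttnBinderDemo([])._token_attention_kernel())                 # normal decode path
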
@@ -133,9 +161,19 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor

-    def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+    def _decompress_kv(
+        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+    ):
         if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if is_fp8:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+
             compressed_kv = self.alloc_tensor(
                 [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
             )
@@ -147,7 +185,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                 infer_state.b_req_idx,
                 infer_state.b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
             )
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
         else:
             compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                 kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
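
Note: to make the added dequantization step concrete, sample_kv is assumed to gather the stored per-token scales into k_scale alongside the fp8 payload, and the two multiplies above promote everything back to bfloat16. A toy version with invented sizes:

import torch

total_tokens, kv_lora_rank, rope_dim = 5, 8, 4                # toy sizes, not the model's
compressed_kv = torch.randn(total_tokens, 1, kv_lora_rank).to(torch.float8_e4m3fn)
k_rope = torch.randn(total_tokens, 1, rope_dim).to(torch.float8_e4m3fn)
k_scale = torch.rand(total_tokens, 1).to(torch.bfloat16)      # one scale per sampled token

# the same two lines the diff adds after sample_kv(...)
compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
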
@@ -177,7 +220,33 @@ def _context_attention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight)
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
+    def _context_attention_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(
@@ -237,6 +306,50 @@ def _context_attention_kernel_origin(

         return o_tensor

+    def _context_attention_kernel_origin_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+            context_attention_fwd_fp8(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -279,6 +392,29 @@ def _token_gqa_decode_attention_flashdecoding(
             alloc_tensor_func=self.alloc_tensor,
         )

+    def _token_gqa_decode_attention_flashdecoding_fp8(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+        return gqa_token_decode_attention_flash_decoding_fp8(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            kv_scale,
+            infer_state,
+            self.tp_q_head_num_,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.qk_nope_head_dim,
+            self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
     def _splitfuse_attention_kernel(
         self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -321,6 +457,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         )
         return

+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
     def _ffn_dp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:

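Note: the _copy_kv_to_mem_cache_fp8 write path above stores the quantized payload and its scale with a single destindex_copy_kv call. A hedged, pure-PyTorch sketch of what that scatter is assumed to do functionally (the real kernel is written in Triton; sizes here are toy values):

import torch

kv_lora_rank, rope_dim = 8, 4                       # toy sizes; DeepSeek-V2 uses 512 / 64
head_dim = kv_lora_rank + rope_dim
kv_buffer = torch.zeros((16, 1, head_dim + 2), dtype=torch.uint8)    # one layer's buffer

mem_index = torch.tensor([3, 7, 12])                # destination slots for 3 new tokens
quant = torch.randint(0, 256, (3, 1, head_dim), dtype=torch.uint8)   # fp8 values as raw bytes
scale_bytes = torch.rand(3, 1, 1).to(torch.bfloat16).view(torch.uint8)

kv_buffer[mem_index, :, :kv_lora_rank] = quant[:, :, :kv_lora_rank]    # nope / lora part
kv_buffer[mem_index, :, kv_lora_rank:-2] = quant[:, :, kv_lora_rank:]  # rope part
kv_buffer[mem_index, :, -2:] = scale_bytes                             # per-token scale bytes
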
lightllm/models/deepseek2/model.py

Lines changed: 5 additions & 1 deletion
@@ -7,6 +7,7 @@

 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger


@@ -48,7 +49,10 @@ def _verify_params(self):
         return super()._verify_params()

     def _init_mem_manager(self):
-        self.mem_manager = Deepseek2MemoryManager(
+        manager_class = Deepseek2MemoryManager
+        if "triton_fp8kv" in self.mode:
+            manager_class = Deepseek2FP8KVMemoryManager
+        self.mem_manager = manager_class(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=1,

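Note: rough, hedged arithmetic for what the mode buys, assuming DeepSeek-V2's compressed KV head of kv_lora_rank=512 plus qk_rope_head_dim=64 and an otherwise bfloat16 cache:

head_dim = 512 + 64                  # compressed KV head per token per layer
bf16_bytes = head_dim * 2            # default bfloat16 cache: 1152 bytes
fp8_bytes = head_dim + 2             # fp8 payload + 2-byte bfloat16 scale: 578 bytes
print(f"fp8 kv cache uses {fp8_bytes / bf16_bytes:.1%} of the bf16 footprint")   # ~50.2%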