
Commit 22db717

Author: niushengxiao
Commit message: feat: add triton_fp8kv mode for deepseek2
1 parent 743ddc3, commit 22db717

File tree

10 files changed: +1040 −19 lines
lightllm/common/deepseek2_fp8kv_mem_manager.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+import torch
+from .deepseek2_mem_manager import Deepseek2MemoryManager
+
+
+class Deepseek2FP8KVMemoryManager(Deepseek2MemoryManager):
+    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
+        # The scale is appended to the end of kv_buffer, so head_dim gains 2 extra bytes; the dtype is unified to uint8.
+        super().__init__(size, torch.uint8, head_num, head_dim + 2, layer_num, always_copy, mem_fraction)
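
For orientation, a minimal sketch of the buffer layout this manager implies. Sizes and variable names here are illustrative, not taken from the commit: each token row stores the fp8-quantized kv values in the first head_dim bytes and a bfloat16 per-token scale in the trailing 2 bytes, which is why head_dim grows by 2 and the dtype becomes uint8.

import torch

# Illustrative sizes; in DeepSeek-V2 MLA head_dim would be kv_lora_rank + qk_rope_head_dim.
tokens, head_num, head_dim = 4, 1, 576
buf = torch.zeros(tokens, head_num, head_dim + 2, dtype=torch.uint8)

values = torch.randn(tokens, head_num, head_dim, dtype=torch.bfloat16)
scale = values.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / torch.finfo(torch.float8_e4m3fn).max

# First head_dim bytes: fp8 payload; last 2 bytes: the bfloat16 scale.
buf[:, :, :head_dim] = (values / scale).to(torch.float8_e4m3fn).view(torch.uint8)
buf[:, :, head_dim:] = scale.view(torch.uint8)

# Read back the same way the fp8 attention code does.
kv = buf[:, :, :-2].view(torch.float8_e4m3fn)
kv_scale = buf[:, :, -2:].view(torch.bfloat16)
dequant = kv.to(torch.bfloat16) * kv_scale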

lightllm/common/mem_utils.py

Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,8 @@ def select_mem_manager_class(mode):
     elif "triton_int8kv" in mode:
         memory_manager_class = INT8KVMemoryManager
         logger.info("Model kv cache using mode triton int8kv")
+    elif "triton_fp8kv" in mode:
+        raise Exception("currently only for deepseek")
     else:
         memory_manager_class = MemoryManager
         logger.info("Model kv cache using mode normal")

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 164 additions & 16 deletions

@@ -10,10 +10,12 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_fp8 import context_attention_fwd_fp8
 from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v
 from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv
 
 from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding
+from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
@@ -22,6 +24,7 @@
 from functools import partial
 from lightllm.models.llama.yarn_rotary_utils import get_deepseek_mscale
 import os
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
 
 
 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -67,19 +70,12 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.enable_opt_decoding_mha = os.getenv("ENABLE_OPT_DECODE_MHA", "False").upper() in ["ON", "TRUE", "1"]
         return
 
-    def _bind_attention(self):
-        if self.enable_cc_method:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
-            )
-        else:
-            self._context_attention_kernel = partial(
-                Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
-            )
-        self._token_attention_kernel = partial(
-            Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
-        )
-        self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+    def _bind_func(self):
+        super()._bind_func()
+        self._bind_ffn()
+        return
+
+    def _bind_ffn(self):
         if self.is_moe:
             if self.enable_dp:
                 if os.environ.get("MOE_MODE", "TP") == "TP":
@@ -92,6 +88,36 @@ def _bind_attention(self):
         else:
             self._ffn = partial(LlamaTransformerLayerInfer._ffn, self)
 
+    def _bind_attention(self):
+        if "triton_fp8kv" in self.mode:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_fp8, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding_fp8, self
+            )
+        else:
+            self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self
+            )
+        if self.enable_cc_method:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self
+                )
+        else:
+            if "triton_fp8kv" in self.mode:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin_fp8, self
+                )
+            else:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self
+                )
+
     def _get_qkv(
         self,
         input: torch.Tensor,
@@ -133,9 +159,19 @@ def _get_o(
         o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim))
         return o_tensor
 
-    def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight):
+    def _decompress_kv(
+        self, kv, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, is_fp8
+    ):
         if infer_state.use_dynamic_prompt_cache:
-            kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+            if is_fp8:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+                kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+                k_scale = self.alloc_tensor([infer_state.total_token_num, 1], dtype=kv_scale.dtype)
+            else:
+                kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+                kv_scale = None
+                k_scale = None
+
             compressed_kv = self.alloc_tensor(
                 [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype
             )
@@ -147,7 +183,12 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
                 infer_state.b_req_idx,
                 infer_state.b_seq_len,
                 infer_state.req_manager.req_to_token_indexs,
+                kv_scale,
+                k_scale,
             )
+            if k_scale is not None:
+                compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
+                k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)
         else:
            compressed_kv, k_rope = torch.split(  # (b*s, 1, kv_lora + qk_r)
                kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1
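
The fp8 branch of _decompress_kv reduces to a per-token rescale after the gather: sample_kv now also collects the stored per-token scale, and the fp8 payload is cast back before attention. A standalone sketch with illustrative shapes (kv_lora_rank=512 and qk_rope_head_dim=64, the usual DeepSeek-V2 values, are assumptions here):

import torch

total_token_num, kv_lora_rank, qk_rope_head_dim = 8, 512, 64

# Stand-ins for what sample_kv would have gathered from the fp8 cache, plus the per-token scale.
compressed_kv = torch.randn(total_token_num, 1, kv_lora_rank).to(torch.float8_e4m3fn)
k_rope = torch.randn(total_token_num, 1, qk_rope_head_dim).to(torch.float8_e4m3fn)
k_scale = torch.rand(total_token_num, 1, dtype=torch.bfloat16)

# Mirrors the is_fp8 path above: cast back to bfloat16 and rescale per token.
compressed_kv = compressed_kv.to(k_scale.dtype) * k_scale.unsqueeze(-1)
k_rope = k_rope.to(k_scale.dtype) * k_scale.unsqueeze(-1)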
@@ -177,7 +218,33 @@ def _context_attention_kernel_with_CC(
         layer_weight: Deepseek2TransformerLayerWeight,
         out=None,
     ) -> torch.Tensor:
-        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight)
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        context_attention_fwd_with_v(
+            q_nope,
+            q_rope,
+            k_nope,
+            k_rope,
+            v,
+            o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]),
+            infer_state.b_start_loc,
+            infer_state.b_seq_len,
+            infer_state.b_ready_cache_len,
+            infer_state.max_len_in_batch,
+            self.softmax_scale,
+        )
+        return o_tensor
+
+    def _context_attention_kernel_with_CC_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, True)
         q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
         o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
         context_attention_fwd_with_v(
@@ -237,6 +304,50 @@ def _context_attention_kernel_origin(
 
         return o_tensor
 
+    def _context_attention_kernel_origin_fp8(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2InferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out
+        if infer_state.use_dynamic_prompt_cache:
+            kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+            kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+            context_attention_fwd_fp8(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                kv_scale,
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_req_idx,
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.b_ready_cache_len,
+                infer_state.max_len_in_batch,
+                infer_state.req_manager.req_to_token_indexs,
+                self.softmax_scale,
+            )
+        else:
+            context_attention_fwd_no_prompt_cache(
+                q_nope,
+                q_rope,
+                kv[:, :, : -self.qk_rope_head_dim],
+                kv[:, :, -self.qk_rope_head_dim :],
+                o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank),
+                infer_state.b_start_loc,
+                infer_state.b_seq_len,
+                infer_state.max_len_in_batch,
+                self.softmax_scale,
+            )
+
+        return o_tensor
+
     def _token_gqa_decode_attention_flashdecoding(
         self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
@@ -279,6 +390,29 @@ def _token_gqa_decode_attention_flashdecoding(
             alloc_tensor_func=self.alloc_tensor,
         )
 
+    def _token_gqa_decode_attention_flashdecoding_fp8(
+        self, q, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, :-2].view(torch.float8_e4m3fn)
+        kv_scale = infer_state.mem_manager.kv_buffer[self.layer_num_][:, :, -2:].view(torch.bfloat16)
+        return gqa_token_decode_attention_flash_decoding_fp8(
+            q_nope,
+            q_rope,
+            kv[:, :, : -self.qk_rope_head_dim],
+            kv[:, :, -self.qk_rope_head_dim :],
+            kv_scale,
+            infer_state,
+            self.tp_q_head_num_,
+            self.kv_lora_rank,
+            self.qk_rope_head_dim,
+            self.qk_nope_head_dim,
+            self.softmax_scale,
+            alloc_tensor_func=self.alloc_tensor,
+        )
+
     def _splitfuse_attention_kernel(
         self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
     ) -> torch.Tensor:
@@ -321,6 +455,20 @@ def _copy_kv_to_mem_cache_normal(self, buffer, mem_index, mem_manager):
         )
         return
 
+    def _copy_kv_to_mem_cache_fp8(self, buffer, mem_index, mem_manager):
+        quant_method = vLLMFP8w8a8QuantizationMethod()
+        quant, scale = quant_method.quantize_scaled_mm_fp8(buffer.reshape(-1, buffer.shape[-1]))
+        destindex_copy_kv(
+            quant.T.unsqueeze(1)[:, :, : self.kv_lora_rank].view(torch.uint8),
+            quant.T.unsqueeze(1)[:, :, self.kv_lora_rank :].view(torch.uint8),
+            mem_index,
+            mem_manager.kv_buffer[self.layer_num_][:, :, : self.kv_lora_rank],
+            mem_manager.kv_buffer[self.layer_num_][:, :, self.kv_lora_rank : -2],
+            mem_manager.kv_buffer[self.layer_num_][:, :, -2:],
+            scale.to(buffer.dtype).view(torch.uint8),
+        )
+        return
+
     def _ffn_dp(
         self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
     ) -> torch.Tensor:
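
For intuition, a hedged sketch of the kind of scaled fp8 quantization that _copy_kv_to_mem_cache_fp8 relies on. This is not the vLLMFP8w8a8QuantizationMethod implementation, and whether the real method scales per token or per tensor is not shown in this diff; the sketch assumes per-token scaling, which matches the per-token scale slot in the cache.

import torch

def quantize_fp8_per_token(x: torch.Tensor):
    """Illustrative per-token scaled fp8 quantization: returns (quantized values, per-token scale)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / fp8_max  # one scale per token row
    quant = (x / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return quant, scale

buffer = torch.randn(16, 576, dtype=torch.bfloat16)  # (tokens, kv_lora_rank + qk_rope_head_dim)
quant, scale = quantize_fp8_per_token(buffer)
dequant = quant.to(torch.bfloat16) * scale
print((dequant - buffer).abs().max())  # small quantization error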

lightllm/models/deepseek2/model.py

Lines changed: 5 additions & 1 deletion

@@ -7,6 +7,7 @@
 
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
+from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager
 from lightllm.utils.log_utils import init_logger
 
 
@@ -48,7 +49,10 @@ def _verify_params(self):
         return super()._verify_params()
 
     def _init_mem_manager(self):
-        self.mem_manager = Deepseek2MemoryManager(
+        manager_class = Deepseek2MemoryManager
+        if "triton_fp8kv" in self.mode:
+            manager_class = Deepseek2FP8KVMemoryManager
+        self.mem_manager = manager_class(
             self.max_total_token_num,
             dtype=self.data_type,
             head_num=1,
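
A minimal standalone sketch of the dispatch this hunk adds, mirroring _init_mem_manager (the mode value is whatever list of mode strings the server was started with; how that list is passed on the command line is not shown in this commit):

from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager
from lightllm.common.deepseek2_fp8kv_mem_manager import Deepseek2FP8KVMemoryManager

def pick_deepseek2_mem_manager(mode):
    # Use the fp8 kv manager only when triton_fp8kv is requested; otherwise keep the default.
    return Deepseek2FP8KVMemoryManager if "triton_fp8kv" in mode else Deepseek2MemoryManager

assert pick_deepseek2_mem_manager(["triton_fp8kv"]) is Deepseek2FP8KVMemoryManager
assert pick_deepseek2_mem_manager([]) is Deepseek2MemoryManager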
