
Commit 8b3a55a: add Fa3 (#858)

Co-authored-by: baishihao <[email protected]>
Parent: 2b2d30f

File tree: 11 files changed (+288, -9 lines)


lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 0 deletions
@@ -76,6 +76,7 @@ def __init__(self, kvargs):
         self._verify_must()
         self._verify_params()
         self._init_quant()
+        self._init_inferstate_cls()

         # More contiguous GPU memory allocation gives better performance
         if self.max_total_token_num is None:
@@ -107,6 +108,9 @@ def _init_config(self):
         self.config["vocab_size"] = self.finetune_config.vocab_size
         return

+    def _init_inferstate_cls(self):
+        pass
+
     @final
     def _verify_must(self):
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
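
The new hook follows a template-method pattern: the base model calls _init_inferstate_cls() during __init__ and leaves it a no-op, so a concrete model class can override it to pick a backend-specific infer-state class (the deepseek2 model below does exactly that). A standalone toy sketch of the pattern; apart from the _init_inferstate_cls and infer_state_class names, everything here is illustrative:

# Toy, self-contained sketch of the hook; only _init_inferstate_cls and
# infer_state_class come from the diff above, everything else is made up.
class ToyBaseModel:
    def __init__(self, enable_fa3: bool):
        self.enable_fa3 = enable_fa3
        self.infer_state_class = dict      # stand-in for the default infer-state class
        self._init_inferstate_cls()        # new hook, invoked during construction

    def _init_inferstate_cls(self):
        pass                               # base class keeps the default


class ToyFa3Model(ToyBaseModel):
    def _init_inferstate_cls(self):
        if self.enable_fa3:
            self.infer_state_class = list  # stand-in for a FlashAttention state class


print(ToyFa3Model(enable_fa3=True).infer_state_class)   # <class 'list'>
print(ToyFa3Model(enable_fa3=False).infer_state_class)  # <class 'dict'>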

lightllm/common/basemodel/infer_struct.py

Lines changed: 1 addition & 1 deletion
@@ -49,6 +49,6 @@ def copy_for_cuda_graph(self, new_infer_state):
         for attr_name, attr_value in vars(new_infer_state).items():
             if isinstance(attr_value, torch.Tensor):
                 attr_ = getattr(self, attr_name, None)
-                if attr_ is not None:
+                if attr_ is not None and attr_.data_ptr() != attr_value.data_ptr():
                     attr_.copy_(attr_value, non_blocking=True)
         return
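
The added data_ptr() guard matters once state tensors can alias shared buffers (such as the page-table buffers introduced in this commit): during CUDA-graph replay the incoming tensor may already be the very storage the graph captured, and copying a tensor onto itself is wasted work. A minimal CPU-only illustration of the guard; toy tensors, not LightLLM code:

import torch

dst = torch.zeros(4)
alias = dst               # same storage, so the same data_ptr()
fresh = torch.ones(4)

for src in (alias, fresh):
    # Same check as the patched line: skip the copy when source and target share storage.
    if src is not None and dst.data_ptr() != src.data_ptr():
        dst.copy_(src, non_blocking=True)

print(dst)  # tensor([1., 1., 1., 1.])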

lightllm/models/deepseek2/flashattention_infer_struct.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo):
+    _shared_page_table_buffer = None
+
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
+        if cls._shared_page_table_buffer is None:
+            cls._shared_page_table_buffer = [
+                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
+                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
+            ]
+        return cls._shared_page_table_buffer
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        if self.is_prefill:
+            self.cu_seqlens_q = torch.nn.functional.pad(
+                torch.cumsum(self.b_seq_len - self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
+            )
+            self.cu_seqlens_k = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+            self.page_table = torch.empty((self.batch_size, self.max_seq_len), dtype=torch.int32).to(input_ids.device)
+            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+        else:
+            # Meta information of flashattention for decoding
+            self.cu_seqlens_q = torch.arange(0, self.batch_size + 1, dtype=torch.int32, device=input_ids.device)
+            self.cu_seqlens_k = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            max_seq_len_k = b_seq_len_numpy.max()
+            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
+                page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(
+                    model.graph_max_batch_size, model.graph_max_len_in_batch
+                )
+                self.page_table = page_buffer[self.microbatch_index][
+                    : self.batch_size * model.graph_max_len_in_batch
+                ].reshape(self.batch_size, model.graph_max_len_in_batch)
+            else:
+                self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
+                    input_ids.device
+                )
+
+            self.page_table[:, :max_seq_len_k].copy_(
+                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
+            )
+            self.page_table[:, max_seq_len_k:].fill_(0)
+        return
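
For context, a small worked example (toy numbers, CPU-only) of how the prefill metadata above falls out of the per-request lengths. It assumes b_start_loc holds each request's start offset in the flattened token layout, which is what the cu_seqlens_k construction implies:

import torch

# Three requests with total lengths 5, 3, 4, of which 2, 0, 1 tokens are already cached.
b_seq_len = torch.tensor([5, 3, 4], dtype=torch.int32)
b_ready_cache_len = torch.tensor([2, 0, 1], dtype=torch.int32)
b_start_loc = torch.tensor([0, 5, 8], dtype=torch.int32)  # exclusive prefix sum of b_seq_len

# New query tokens per request are 3, 3, 3, so cu_seqlens_q = [0, 3, 6, 9].
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(b_seq_len - b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
)

# Key/value boundaries over the full cached sequences: [0, 5, 8, 12].
cu_seqlens_k = torch.cat([b_start_loc, b_start_loc[-1:] + b_seq_len[-1:]], dim=0)

print(cu_seqlens_q.tolist())  # [0, 3, 6, 9]
print(cu_seqlens_k.tolist())  # [0, 5, 8, 12]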

lightllm/models/deepseek2/infer_struct.py

Lines changed: 0 additions & 2 deletions
@@ -12,13 +12,11 @@ def __init__(self):

     def init_some_extra_state(self, model, input_ids: torch.Tensor):
         super().init_some_extra_state(model, input_ids)
-        # This bookkeeping variable is only needed when the decode stage uses the ppl optimized kernels
         if not self.is_prefill:
            self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
            self.total_token_num_tensor = torch.sum(self.b_seq_len)

         if self.is_prefill:
             self.b_kv_start_loc = self.b_seq_len.cumsum(dim=0) - self.b_seq_len
             self.max_value_in_b_seq_len = self.b_seq_len.max().item()
-
         return

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 71 additions & 2 deletions
@@ -29,6 +29,14 @@
 from lightllm.distributed.communication_op import all_gather, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+except:
+    logger.warning("sgl_kernel is not installed, or the installed version does not support fa3!")


 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
@@ -93,7 +101,11 @@ def _bind_attention(self):
             )
         else:
             self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
-        if get_env_start_args().enable_flashinfer_decode:
+        if get_env_start_args().enable_fa3:
+            self._token_attention_kernel = partial(
+                Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
+            )
+        elif get_env_start_args().enable_flashinfer_decode:
             self._token_attention_kernel = partial(
                 Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self
             )
@@ -112,7 +124,11 @@ def _bind_attention(self):
                 Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC_fp8, self
             )
         else:
-            if get_env_start_args().enable_flashinfer_prefill:
+            if get_env_start_args().enable_fa3:
+                self._context_attention_kernel = partial(
+                    Deepseek2TransformerLayerInfer._context_attention_flashattention_kernel_with_CC, self
+                )
+            elif get_env_start_args().enable_flashinfer_prefill:
                 self._context_attention_kernel = partial(
                     Deepseek2TransformerLayerInfer._context_attention_flashinfer_kernel_with_CC, self
                 )
@@ -278,6 +294,30 @@ def _decompress_kv(
         k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v

+    def _context_attention_flashattention_kernel_with_CC(
+        self,
+        q: torch.Tensor,
+        kv,
+        infer_state: Deepseek2FlashInferStateInfo,
+        layer_weight: Deepseek2TransformerLayerWeight,
+        out=None,
+    ) -> torch.Tensor:
+        k_nope, k_rope, v = self._decompress_kv(kv, infer_state, layer_weight, False)
+        k = torch.cat([k_nope, torch.repeat_interleave(k_rope, self.tp_q_head_num_, dim=-2)], dim=-1)
+        o_tensor = flash_attn_varlen_func(
+            q=q.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            k=k.view(-1, self.tp_k_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
+            v=v.view(-1, self.tp_v_head_num_, self.v_head_dim),
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k=infer_state.cu_seqlens_k,
+            max_seqlen_q=infer_state.q_max_seq_len,
+            max_seqlen_k=infer_state.max_seq_len,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _context_attention_flashinfer_kernel_with_CC(
         self,
         q: torch.Tensor,
@@ -450,6 +490,35 @@ def _context_attention_kernel_origin_fp8(

         return o_tensor

+    def _token_gqa_decode_attention_flashattention(
+        self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
+    ):
+        q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
+        q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
+        kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
+        k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
+        kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
+        k_descale, v_descale = None, None
+        o_tensor = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope,
+            v_cache=kv_nope,
+            qv=q_nope,
+            page_table=infer_state.page_table,
+            cache_seqlens=infer_state.b_seq_len,
+            cu_seqlens_q=infer_state.cu_seqlens_q,
+            cu_seqlens_k_new=infer_state.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=self.softmax_scale,
+            causal=True,
+            window_size=(-1, -1),
+            softcap=0.0,
+            k_descale=k_descale,
+            v_descale=v_descale,
+            return_softmax_lse=False,
+        )
+        return o_tensor
+
     def _token_gqa_decode_attention_flashinfer(
         self, q, infer_state: Deepseek2FlashInferStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
     ):
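
A shape-only sketch of the tensor bookkeeping in the new decode kernel above, using DeepSeek-V2-style head sizes (128 nope, 64 rope, 512 latent) and torch.bmm standing in for layer_weight.k_b_proj_.bmm; everything here is illustrative rather than LightLLM code:

import torch

tokens, heads = 4, 16
qk_nope_head_dim, qk_rope_head_dim, kv_lora_rank = 128, 64, 512

# q arrives as [tokens, heads, nope + rope]; split off the rotary part.
q = torch.randn(tokens, heads, qk_nope_head_dim + qk_rope_head_dim)
q_nope, q_rope = q[:, :, :-qk_rope_head_dim], q[:, :, -qk_rope_head_dim:]

# "Absorb" the per-head nope query into the latent space with a batched matmul.
k_b_proj = torch.randn(heads, qk_nope_head_dim, kv_lora_rank)
q_nope = torch.bmm(q_nope.transpose(0, 1), k_b_proj).transpose(0, 1)
print(q_nope.shape)  # torch.Size([4, 16, 512])

# The paged MLA cache stores one (latent + rope) vector per token; view it as
# separate rope-key and latent-"value" caches, each with a single KV head.
num_cached_tokens = 32
kv = torch.randn(num_cached_tokens, 1, kv_lora_rank + qk_rope_head_dim)
k_rope = kv[:, :, -qk_rope_head_dim:].reshape(-1, 1, 1, qk_rope_head_dim)
kv_nope = kv[:, :, :-qk_rope_head_dim].reshape(-1, 1, 1, kv_lora_rank)
print(k_rope.shape, kv_nope.shape)  # [32, 1, 1, 64] and [32, 1, 1, 512]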

lightllm/models/deepseek2/model.py

Lines changed: 9 additions & 4 deletions
@@ -4,6 +4,7 @@
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
 from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo
+from lightllm.models.deepseek2.flashattention_infer_struct import Deepseek2FlashAttentionStateInfo
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

 from lightllm.models.llama.model import LlamaTpPartModel
@@ -62,11 +63,17 @@ def __init__(self, kvargs):
         self.enable_flashinfer = (
             get_env_start_args().enable_flashinfer_prefill or get_env_start_args().enable_flashinfer_decode
         )
-        if self.enable_flashinfer:
-            self.infer_state_class = Deepseek2FlashInferStateInfo
         super().__init__(kvargs)
         return

+    def _init_inferstate_cls(self):
+        if get_env_start_args().enable_fa3:
+            self.infer_state_class = Deepseek2FlashAttentionStateInfo
+        elif self.enable_flashinfer:
+            self.infer_state_class = Deepseek2FlashInferStateInfo
+        if self.enable_flashinfer:
+            self.flashinfer_extra_state = FlashInferStateExtraInfo(self)
+
     def _init_some_value(self):
         super()._init_some_value()
         self.tp_k_head_num_ = 1
@@ -77,8 +84,6 @@ def _init_some_value(self):
         self.q_lora_rank = self.config["q_lora_rank"]
         self.kv_lora_rank = self.config["kv_lora_rank"]
         self.head_dim_ = self.kv_lora_rank + self.qk_rope_head_dim
-        if self.enable_flashinfer:
-            self.flashinfer_extra_state = FlashInferStateExtraInfo(self)

     def _init_custom(self):
         self._init_to_get_yarn_rotary()
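
The override above encodes a simple precedence: FA3 wins over flashinfer for the infer-state class, and the flashinfer extra state (moved here from _init_some_value) is still built whenever either flashinfer flag is set. A toy restatement of the selection order, using strings as stand-ins and assuming the class-level default state class stays Deepseek2InferStateInfo:

def pick_infer_state_class(enable_fa3: bool, enable_flashinfer: bool) -> str:
    # Mirrors the _init_inferstate_cls override above, with strings as stand-ins.
    if enable_fa3:
        return "Deepseek2FlashAttentionStateInfo"
    elif enable_flashinfer:
        return "Deepseek2FlashInferStateInfo"
    return "Deepseek2InferStateInfo"


print(pick_infer_state_class(True, True))    # FA3 takes precedence
print(pick_infer_state_class(False, True))   # Deepseek2FlashInferStateInfo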
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index
+
+
+class FlashAttentionStateInfo(LlamaInferStateInfo):
+    _shared_page_table_buffer = None
+
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int):
+        if cls._shared_page_table_buffer is None:
+            cls._shared_page_table_buffer = [
+                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
+                torch.empty(graph_max_batch_size * max_seq_len, dtype=torch.int32).to(get_current_device_id()),
+            ]
+        return cls._shared_page_table_buffer
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        super().init_some_extra_state(model, input_ids)
+        if self.is_prefill:
+            self.cu_seqlens_q = torch.nn.functional.pad(
+                torch.cumsum(self.b_seq_len - self.b_ready_cache_len, dim=0, dtype=torch.int32), (1, 0)
+            )
+            self.cu_seqlens_k = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+            self.page_table = torch.empty((self.batch_size, self.max_seq_len), dtype=torch.int32).to(input_ids.device)
+            self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len])
+        else:
+            # Meta information of flashattention for decoding
+            self.cu_seqlens_q = torch.arange(0, self.batch_size + 1, dtype=torch.int32, device=input_ids.device)
+            self.cu_seqlens_k = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+            b_seq_len_numpy = self.b_seq_len.cpu().numpy()
+            max_seq_len_k = b_seq_len_numpy.max()
+            if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch:
+                page_buffer = FlashAttentionStateInfo.get_page_table_buffer(
+                    model.graph_max_batch_size, model.graph_max_len_in_batch
+                )
+                self.page_table = page_buffer[self.microbatch_index][
+                    : self.batch_size * model.graph_max_len_in_batch
+                ].reshape(self.batch_size, model.graph_max_len_in_batch)
+            else:
+                self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to(
+                    input_ids.device
+                )
+
+            self.page_table[:, :max_seq_len_k].copy_(
+                model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k]
+            )
+            self.page_table[:, max_seq_len_k:].fill_(0)
+        return
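
Both new state classes share the same page-table trick for CUDA graphs: one flat buffer per microbatch is preallocated at graph_max_batch_size * graph_max_len_in_batch, and each batch slices and reshapes a view of it, so the tensor address the captured graph sees never changes. A minimal CPU-only illustration with toy sizes:

import torch

graph_max_batch_size, graph_max_len_in_batch = 8, 16

# One flat int32 buffer per microbatch index (0 and 1), allocated once.
page_buffers = [
    torch.empty(graph_max_batch_size * graph_max_len_in_batch, dtype=torch.int32) for _ in range(2)
]

batch_size, microbatch_index = 3, 0
page_table = page_buffers[microbatch_index][: batch_size * graph_max_len_in_batch].reshape(
    batch_size, graph_max_len_in_batch
)

# The per-batch page table is a view into the shared buffer, so its address is stable.
assert page_table.data_ptr() == page_buffers[microbatch_index].data_ptr()
print(page_table.shape)  # torch.Size([3, 16])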

lightllm/models/llama/infer_struct.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         b_seq_len_numpy = self.b_seq_len.cpu().numpy()
         self.max_seq_len = b_seq_len_numpy.max()
         b_ready_cache_len_numpy = self.b_ready_cache_len.cpu().numpy()
+        self.q_max_seq_len = (b_seq_len_numpy - b_ready_cache_len_numpy).max()
         position_ids = torch.from_numpy(
             np.concatenate(
                 [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))],
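
A quick numeric check (toy values) of what the new q_max_seq_len captures: the largest number of not-yet-cached query tokens in the batch, as opposed to max_seq_len, which is the largest total sequence length:

import numpy as np

b_seq_len_numpy = np.array([5, 9, 4])           # total tokens per request
b_ready_cache_len_numpy = np.array([2, 8, 0])   # tokens already in the KV cache

max_seq_len = b_seq_len_numpy.max()                                # 9
q_max_seq_len = (b_seq_len_numpy - b_ready_cache_len_numpy).max()  # max(3, 1, 4) = 4
print(max_seq_len, q_max_seq_len)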
