Commit 8d04a25

Chunked prefill (#717)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 65e2747 commit 8d04a25

46 files changed, +299 -802 lines

docs/CN/source/models/test.rst

Lines changed: 2 additions & 2 deletions

@@ -135,12 +135,12 @@ internlm2-1_8b

     $ python -m lightllm.server.api_server
     $ --model_dir ~/models/internlm2-1_8b \
-    $ --splitfuse_mode \
+    $ --enable_chunked_prefill \
     $ --trust_remote_code

 .. tip::

-    ``--splitfuse_mode`` means splitfuse is used for acceleration
+    ``--enable_chunked_prefill`` means chunked prefill is used for long-text inference

 **Test Server**

docs/EN/source/models/test.rst

Lines changed: 2 additions & 2 deletions

@@ -213,12 +213,12 @@ internlm2-1_8b
 .. code-block:: console

     $ python -m lightllm.server.api_server --model_dir ~/models/internlm2-1_8b \
-    $ --splitfuse_mode \
+    $ --enable_chunked_prefill \
     $ --trust_remote_code

 .. tip::

-    ``--splitfuse_mode`` Indicates the use of splitfuse for acceleration.
+    ``--enable_chunked_prefill`` Indicates the use of chunked prefill for long context.

 **Test Server**
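
With the server relaunched using ``--enable_chunked_prefill``, the change can be exercised end to end. Below is a minimal sketch of a test request, assuming lightllm's default port 8000 and its /generate endpoint (both are assumptions; adjust to your deployment):

    import requests

    # Assumed local deployment: default api_server port and /generate route.
    url = "http://127.0.0.1:8000/generate"
    payload = {
        "inputs": "What is chunked prefill?",
        "parameters": {"max_new_tokens": 64},
    }

    # Send one generation request and print the server's JSON response.
    resp = requests.post(url, json=payload, timeout=60)
    print(resp.json())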

lightllm/common/basemodel/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -9,7 +9,6 @@
 from .layer_infer.template.pre_layer_infer_template import PreLayerInferTpl
 from .layer_infer.template.post_layer_infer_template import PostLayerInferTpl
 from .infer_struct import InferStateInfo
-from .splitfuse_infer_struct import SplitFuseInferStateInfo
 from .basemodel import TpPartBaseModel


@@ -26,5 +25,4 @@
     "TpPartBaseModel",
     "PreLayerInferTpl",
     "PostLayerInferTpl",
-    "SplitFuseInferStateInfo",
 ]

lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 89 deletions

@@ -7,13 +7,11 @@

 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
 from lightllm.common.mem_manager import MemoryManager
 from lightllm.common.req_manager import ReqManager
 from lightllm.common.infer_utils import init_req_to_token_indexes
 from lightllm.common.build_utils import repair_config
 from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
-from lightllm.common.basemodel.triton_kernel.splitfuse_copy_kv_index_to_req import splitfuse_copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg

@@ -36,7 +34,6 @@ class TpPartBaseModel:

     # infer state class
     infer_state_class = InferStateInfo
-    splitfuse_infer_state_class = SplitFuseInferStateInfo

     def __init__(self, kvargs):
         self.run_mode = kvargs["run_mode"]

@@ -57,6 +54,8 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
+        enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is default on.
+        self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)

@@ -368,81 +367,6 @@ def _decode(
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics

-    @torch.no_grad()
-    def splitfuse_forward(
-        self,
-        input_ids,
-        mem_indexes,
-        decode_req_num,
-        decode_total_token_num,
-        decode_b_req_idx: torch.Tensor,
-        decode_b_start_loc: torch.Tensor,
-        decode_b_seq_len: torch.Tensor,
-        decode_max_len_in_batch,
-        prefill_req_num,
-        prefill_b_req_idx: torch.Tensor,
-        prefill_b_split_start_loc: torch.Tensor,
-        prefill_b_split_ready_cache_len: torch.Tensor,
-        prefill_max_split_seq_len_in_batch,
-        prefill_b_seq_len: torch.Tensor,
-    ):
-
-        infer_state = self.splitfuse_infer_state_class()
-        infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
-        infer_state.batch_size = decode_req_num + prefill_req_num
-
-        infer_state.decode_req_num = decode_req_num
-        infer_state.decode_total_token_num = decode_total_token_num
-        infer_state.decode_b_req_idx = decode_b_req_idx
-        infer_state.decode_b_start_loc = decode_b_start_loc
-        infer_state.decode_b_seq_len = decode_b_seq_len
-        infer_state.decode_max_len_in_batch = decode_max_len_in_batch
-
-        infer_state.prefill_req_num = prefill_req_num
-        infer_state.prefill_b_req_idx = prefill_b_req_idx
-        infer_state.prefill_b_split_start_loc = prefill_b_split_start_loc
-        infer_state.prefill_b_split_ready_cache_len = prefill_b_split_ready_cache_len
-        infer_state.prefill_max_split_seq_len_in_batch = prefill_max_split_seq_len_in_batch
-        infer_state.prefill_b_seq_len = prefill_b_seq_len
-        # infer_state.event = [torch.cuda.Event() for _ in range(self.layers_num)]
-
-        infer_state.mem_manager = self.mem_manager
-        infer_state.req_manager = self.req_manager
-
-        alloc_size = len(input_ids)
-        infer_state.mem_is_contiguous = False
-        infer_state.mem_index = mem_indexes
-        infer_state.kv_buffer = torch.empty(
-            (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
-            dtype=self.data_type,
-            device="cuda",
-        )
-
-        # decode part
-        if decode_req_num != 0:
-            copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                decode_b_req_idx,
-                decode_b_seq_len,
-                infer_state.mem_index[0:decode_req_num],
-            )
-
-        # split prefill part
-        if prefill_req_num != 0:
-            splitfuse_copy_kv_index_to_req(
-                self.req_manager.req_to_token_indexs,
-                prefill_b_req_idx,
-                prefill_b_split_ready_cache_len,
-                prefill_b_seq_len,
-                infer_state.mem_index[decode_req_num:],
-            )
-
-        infer_state.init_some_extra_state(self, input_ids)
-        infer_state.create_inner_decode_infer_status()
-        infer_state.create_inner_prefill_infer_status()
-        predict_logics = self._splitfuse_forward(input_ids, infer_state)
-        return predict_logics
-
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
         g_cache_manager.cache_env_in()

@@ -469,17 +393,6 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
         g_cache_manager.cache_env_out()
         return predict_logics

-    @final
-    def _splitfuse_forward(self, input_ids, infer_state: SplitFuseInferStateInfo):
-        g_cache_manager.cache_env_in()
-        cuda_input_ids = input_ids
-        input_embs = self.pre_infer.splitfuse_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].splitfuse_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.splitfuse_forward(input_embs, infer_state, self.pre_post_weight)
-        g_cache_manager.cache_env_out()
-        return predict_logics
-
     @final
     @torch.no_grad()
     def _check_max_len_infer(self):
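
The only functional addition in this file is the pair of lines in __init__: enabling chunked prefill also forces use_dynamic_prompt_cache on, presumably because later prefill chunks must read the KV cache written by earlier chunks of the same request. The scheduling idea itself can be illustrated with a minimal, self-contained sketch; the function name and chunk size below are hypothetical and are not lightllm's actual scheduler:

    from typing import List


    def chunked_prefill_schedule(prompt_len: int, chunk_size: int) -> List[slice]:
        """Split one prompt's prefill into fixed-size chunks, one chunk per forward step."""
        return [
            slice(start, min(start + chunk_size, prompt_len))
            for start in range(0, prompt_len, chunk_size)
        ]


    # A 10k-token prompt handled in 2k-token chunks: five short prefill steps,
    # so decode steps of other running requests can be interleaved between them
    # instead of waiting behind one monolithic prefill pass.
    for step, chunk in enumerate(chunked_prefill_schedule(10_000, 2_048)):
        print(f"prefill step {step}: tokens [{chunk.start}, {chunk.stop})")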

lightllm/common/basemodel/infer_struct.py

Lines changed: 0 additions & 1 deletion

@@ -30,7 +30,6 @@ def __init__(self):
         self.mem_end = None
         self.kv_buffer = None

-        self.is_splitfuse = False
         self.is_token_healing = False
         self.return_all_prompt_logics = False
         self.use_dynamic_prompt_cache = False

lightllm/common/basemodel/layer_infer/base_layer_infer.py

Lines changed: 0 additions & 4 deletions

@@ -1,7 +1,6 @@
 import torch
 from typing import Dict, Iterable, Literal, Tuple, Union, List
 from lightllm.common.basemodel.infer_struct import InferStateInfo
-from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
 from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
 from .cache_tensor_manager import g_cache_manager


@@ -16,9 +15,6 @@ def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight:
     def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BaseLayerWeight):
         raise Exception("need to impl")

-    def splitfuse_forward(self, input_ids, infer_state: SplitFuseInferStateInfo, layer_weight: BaseLayerWeight):
-        raise Exception("need to impl")
-
     def alloc_tensor(
         self,
         shape: Union[torch.Size, Iterable[int]],

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 0 additions & 32 deletions

@@ -8,7 +8,6 @@
 from lightllm.utils.infer_utils import mark_cost_time

 from ...infer_struct import InferStateInfo
-from ...splitfuse_infer_struct import SplitFuseInferStateInfo
 from ..transformer_layer_infer import TransformerLayerInfer


@@ -69,11 +68,6 @@ def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_we
     def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
         raise Exception("need to impl")

-    def _splitfuse_attention_kernel(
-        self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
-    ) -> torch.Tensor:
-        raise Exception("need to impl")
-
     def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")

@@ -118,25 +112,6 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         infer_state._ffn_out = ffn_out
         return

-    def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
-        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
-        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
-        self._post_cache_kv(cache_kv, infer_state, layer_weight)
-        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
-        q = None
-        o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
-            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
-        infer_state._attn_out = o
-        return
-
-    def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
-        ffn_out = self._ffn(input_embdings, infer_state, layer_weight)
-        if self.world_size_ > 1:
-            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
-        infer_state._ffn_out = ffn_out
-        return
-
     def _cohere_residual(self, input_embdings, infer_state: InferStateInfo):
         # emb_addr = input_embdings.data_ptr()
         # attn_out_addr = infer_state._attn_out.data_ptr()

@@ -161,10 +136,3 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
         self._token_ffn(input1, infer_state, layer_weight)
         self._cohere_residual(input_embdings, infer_state)
         return input_embdings
-
-    def splitfuse_forward(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
-        self._splitfuse_attention(input1, infer_state, layer_weight=layer_weight)
-        self._splitfuse_ffn(input1, infer_state, layer_weight)
-        self._cohere_residual(input_embdings, infer_state)
-        return input_embdings

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 0 additions & 34 deletions

@@ -3,7 +3,6 @@
 import torch.distributed as dist
 from ..transformer_layer_infer import TransformerLayerInfer
 from ...infer_struct import InferStateInfo
-from ...splitfuse_infer_struct import SplitFuseInferStateInfo
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
 from typing import Tuple

@@ -61,11 +60,6 @@ def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_we
     def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
         raise Exception("need to impl")

-    def _splitfuse_attention_kernel(
-        self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
-    ) -> torch.Tensor:
-        raise Exception("need to impl")
-
     def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
         raise Exception("need to impl")

@@ -118,29 +112,6 @@ def _token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return

-    def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight):
-        input1 = self._att_norm(input_embding, infer_state, layer_weight)
-        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
-        q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight)
-        input1 = None
-        self._post_cache_kv(cache_kv, infer_state, layer_weight)
-        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
-        q = None
-        o = self._get_o(o, infer_state, layer_weight)
-        if self.world_size_ > 1:
-            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
-        input_embding.add_(o.view(-1, self.embed_dim_))
-        return
-
-    def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
-        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
-        ffn_out = self._ffn(input1, infer_state, layer_weight)
-        input1 = None
-        if self.world_size_ > 1:
-            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
-        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
-        return
-
     def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
         self._context_attention(input_embdings, infer_state, layer_weight=layer_weight)
         self._context_ffn(input_embdings, infer_state, layer_weight)

@@ -150,8 +121,3 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
         self._token_attention(input_embdings, infer_state, layer_weight=layer_weight)
         self._token_ffn(input_embdings, infer_state, layer_weight)
         return input_embdings
-
-    def splitfuse_forward(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight):
-        self._splitfuse_attention(input_embdings, infer_state, layer_weight=layer_weight)
-        self._splitfuse_ffn(input_embdings, infer_state, layer_weight)
-        return input_embdings
