Commit 5525b77

fix
1 parent 572ce8d commit 5525b77

5 files changed: +36 additions, -8 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 10 additions & 1 deletion

@@ -27,7 +27,7 @@
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status
+from lightllm.utils.envs_utils import set_model_init_status, enable_diverse_mode_gqa_decode_fast_kernel
 from lightllm.common.triton_utils.autotuner import Autotuner
 from lightllm.utils.infer_utils import post_empty_cache

@@ -319,6 +319,15 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
             mode="constant",
             value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
         )
+        if enable_diverse_mode_gqa_decode_fast_kernel():
+            if new_model_input.b_shared_seq_len is not None:
+                new_model_input.b_shared_seq_len = F.pad(
+                    new_model_input.b_shared_seq_len, (0, padded_batch_size), mode="constant", value=0
+                )
+            if new_model_input.b_mark_shared_group is not None:
+                new_model_input.b_mark_shared_group = F.pad(
+                    new_model_input.b_mark_shared_group, (0, padded_batch_size), mode="constant", value=1
+                )

        # Special-case padding for special variables of special models and modes
        if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None:
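As a side note, here is a minimal runnable sketch of what this constant padding does to the two new 1-D decode tensors (the tensor values below are invented for illustration, not taken from the commit): padded slots get a shared length of 0 and a group marker of 1, so every dummy request closes its own singleton group.

import torch
import torch.nn.functional as F

# Illustrative values only: three real decode requests forming one shared-prefix group.
b_shared_seq_len = torch.tensor([10, 10, 10], dtype=torch.int32)
b_mark_shared_group = torch.tensor([0, 0, 3], dtype=torch.int32)

pad_len = 2  # number of dummy decode slots appended on the right
b_shared_seq_len = F.pad(b_shared_seq_len, (0, pad_len), mode="constant", value=0)
b_mark_shared_group = F.pad(b_mark_shared_group, (0, pad_len), mode="constant", value=1)

print(b_shared_seq_len)     # tensor([10, 10, 10,  0,  0], dtype=torch.int32)
print(b_mark_shared_group)  # tensor([0, 0, 3, 1, 1], dtype=torch.int32)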

lightllm/common/basemodel/batch_objs.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 from dataclasses import dataclass, field
 from typing import Optional
 from typing import List
-from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel


 @dataclass

@@ -62,7 +62,7 @@ def to_cuda(self):
            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
        if self.b_prefill_start_loc is not None:
            self.b_prefill_start_loc = self.b_prefill_start_loc.cuda(non_blocking=True)
-        if not self.is_prefill and get_env_start_args().diverse_mode:
+        if not self.is_prefill and enable_diverse_mode_gqa_decode_fast_kernel():
            batch_size = len(self.b_req_idx)
            if self.b_mark_shared_group is None:
                self.b_mark_shared_group = torch.ones(size=(batch_size,), dtype=torch.int32, device="cuda")
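For context, the to_cuda() fallback above means a decode batch that arrives without sharing metadata is treated as all singleton groups. A tiny device-agnostic sketch of that default (batch contents invented here):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical decode batch of four requests with no sharing metadata attached.
b_req_idx = torch.tensor([7, 3, 9, 2], dtype=torch.int32, device=device)
b_mark_shared_group = None

if b_mark_shared_group is None:
    # Same default as in to_cuda(): every request closes its own group of size 1,
    # so downstream kernels always see a well-formed grouping.
    b_mark_shared_group = torch.ones(size=(len(b_req_idx),), dtype=torch.int32, device=device)

print(b_mark_shared_group)  # tensor([1, 1, 1, 1], dtype=torch.int32), plus the device when on GPU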

lightllm/server/api_cli.py

Lines changed: 2 additions & 1 deletion

@@ -204,7 +204,8 @@ def make_argument_parser() -> argparse.ArgumentParser:
        type=str,
        default=[],
        nargs="+",
-        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding
+        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_int8kv_flashdecoding | ppl_int8kv_flashdecoding_diverse
+                        | ppl_fp16 | triton_flashdecoding
                        | triton_gqa_attention | triton_gqa_flashdecoding | triton_fp8kv | offline_calibration_fp8kv
                        | export_fp8kv_calibration
                        triton_flashdecoding mode is for long context, current support llama llama2 qwen;

lightllm/server/router/model_infer/mode_backend/generic_pre_process.py

Lines changed: 17 additions & 4 deletions

@@ -4,7 +4,10 @@
 from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 from lightllm.common.basemodel.batch_objs import ModelInput
-from lightllm.utils.envs_utils import get_env_start_args, get_diverse_max_batch_shared_group_size
+from lightllm.utils.envs_utils import (
+    enable_diverse_mode_gqa_decode_fast_kernel,
+    get_diverse_max_batch_shared_group_size,
+)


 def prepare_prefill_inputs(

@@ -93,7 +96,7 @@ def prepare_prefill_inputs(


 def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[InferReq]]:
-    run_reqs = []
+    run_reqs: List[InferReq] = []
     total_token_num = 0
     max_len_in_batch = 0
     b_req_idx = []

@@ -130,9 +133,18 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In

     b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
     b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
-    b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
     b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+
+    if enable_diverse_mode_gqa_decode_fast_kernel():
+        # b_shared_seq_len and b_mark_shared_group are only actually used in the decode phase
+        # under diverse_mode; they record the prefix-sharing relationship between requests.
+        # For example:
+        # b_shared_seq_len   : [10, 10, 10, 11, 11, 11, 11]
+        # b_mark_shared_group: [0, 0, 3, 0, 0, 0, 4]
+        # Each nonzero entry of b_mark_shared_group marks the last request of a shared-prefix
+        # group and gives the group size (itself included). Requests belonging to the same
+        # shared-prefix group must have identical values in b_shared_seq_len. Some modes can
+        # use these two inputs to speed up kernel execution.
+        b_shared_seq_len = torch.tensor(b_shared_seq_len, dtype=torch.int32, device="cpu")
-    if get_env_start_args().diverse_mode:
         b_mark_shared_group = []
         shared_nodes = [req.shared_kv_node for req in run_reqs]
         _current_group = []

@@ -159,6 +171,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In
         assert len(b_mark_shared_group) == len(run_reqs)
         b_mark_shared_group = torch.tensor(b_mark_shared_group, dtype=torch.int32, device="cpu")
     else:
+        b_shared_seq_len = None
         b_mark_shared_group = None

     # Prepare tokens for the dynamic prompt cache
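The body of the grouping loop is elided from this diff, so the following is an illustration only (not the commit's code) of how markers in the layout described by the comment can be derived: scan consecutive requests, close a group whenever the shared node changes (or, presumably, when the cap from get_diverse_max_batch_shared_group_size() is reached), and write the group size at the closing position. Plain integers stand in for the real req.shared_kv_node objects.

from typing import List

def mark_shared_groups(shared_node_ids: List[int], max_group_size: int = 4) -> List[int]:
    # Illustrative reimplementation: every position gets 0 except the last request of
    # each run of identical shared nodes, which gets the size of that run (capped).
    marks: List[int] = []
    group_len = 0
    for i, node in enumerate(shared_node_ids):
        group_len += 1
        is_last = i + 1 == len(shared_node_ids)
        next_differs = (not is_last) and shared_node_ids[i + 1] != node
        if is_last or next_differs or group_len == max_group_size:
            marks.extend([0] * (group_len - 1) + [group_len])
            group_len = 0
    return marks

# Reproduces the layout from the comment in the hunk above:
print(mark_shared_groups([1, 1, 1, 2, 2, 2, 2]))  # [0, 0, 3, 0, 0, 0, 4]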

lightllm/utils/envs_utils.py

Lines changed: 5 additions & 0 deletions

@@ -199,3 +199,8 @@ def get_radix_tree_merge_update_delta() -> int:
 @lru_cache(maxsize=None)
 def get_diverse_max_batch_shared_group_size() -> int:
     return int(os.getenv("LIGHTLLM_MAX_BATCH_SHARED_GROUP_SIZE", 4))
+
+
+@lru_cache(maxsize=None)
+def enable_diverse_mode_gqa_decode_fast_kernel() -> bool:
+    return get_env_start_args().diverse_mode and "ppl_int8kv_flashdecoding_diverse" in get_env_start_args().mode
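So the new gate returns True only when the server runs in diverse mode and the new ppl_int8kv_flashdecoding_diverse entry appears in its --mode list; thanks to lru_cache it is evaluated once per process. A minimal sketch of the same predicate over a hypothetical stand-in for the parsed start args:

from dataclasses import dataclass, field
from typing import List

@dataclass
class StartArgs:  # hypothetical stand-in for the real parsed start args
    diverse_mode: bool = False
    mode: List[str] = field(default_factory=list)

def fast_kernel_enabled(args: StartArgs) -> bool:
    # Same predicate as enable_diverse_mode_gqa_decode_fast_kernel().
    return args.diverse_mode and "ppl_int8kv_flashdecoding_diverse" in args.mode

print(fast_kernel_enabled(StartArgs(diverse_mode=True, mode=["ppl_int8kv_flashdecoding_diverse"])))   # True
print(fast_kernel_enabled(StartArgs(diverse_mode=True, mode=["ppl_int8kv_flashdecoding"])))           # False
print(fast_kernel_enabled(StartArgs(diverse_mode=False, mode=["ppl_int8kv_flashdecoding_diverse"])))  # False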
