
Commit 87dc04b

niushengxiao authored and blueswhen committed
feat: clean code
1 parent 4c99584 commit 87dc04b

File tree: 18 files changed (+209, -1128 lines)


lightllm/common/basemodel/basemodel.py

Lines changed: 9 additions & 9 deletions
@@ -61,8 +61,8 @@ def __init__(self, kvargs):
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
-        # An optional hook function, called before model warmup, used to wait for external initialization (e.g. for CPU KV Cache registration to finish)
-        self._pre_warmup_hook = kvargs.get("pre_warmup_hook", None)
+        # Used to wait for external initialization (e.g. for CPU KV Cache registration to finish)
+        self.waiting_hook = kvargs.get("waiting_hook", None)
         # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one can be active at a time.
         # They mainly control how many tokens the prefill stage returns for later processing.
         self.is_token_healing = kvargs.get("is_token_healing", False)

@@ -108,19 +108,19 @@ def __init__(self, kvargs):
         self._init_inferstate_cls()
         self._autotune_warmup()
         self._init_padded_req()
-        # Run the optional pre-warmup hook before entering autotune warmup (e.g. wait for CPU KV Cache registration to finish)
-        if callable(self._pre_warmup_hook):
-            try:
-                self._pre_warmup_hook()
-            except Exception as e:
-                logger.exception(f"pre_warmup_hook failed: {e}")
-                raise
+        # The wait must happen before init cudagraph, to avoid it being captured by mistake.
+        self._run_waiting_hook()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
         return

+    def _run_waiting_hook(self):
+        if self.waiting_hook is not None:
+            self.waiting_hook()
+        return
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
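
The renamed waiting_hook is an optional parameterless callable passed in through kvargs and invoked once before _init_cudagraph(). Below is a minimal sketch of how a caller might wire it up to block model init until CPU KV cache registration has finished; the event and function names are illustrative, not the actual lightllm plumbing:

    import threading

    # Hypothetical signal set by whichever process registers the CPU KV cache.
    cpu_cache_ready = threading.Event()

    def wait_for_cpu_cache_registration(timeout_s: float = 600.0):
        # Runs inside BaseModel._run_waiting_hook(), i.e. before the CUDA graph
        # is initialized, so the wait can never be captured into a graph.
        if not cpu_cache_ready.wait(timeout=timeout_s):
            raise TimeoutError("CPU KV cache registration did not finish in time")

    kvargs = {
        # ... the usual model kvargs (weight_dir, max_req_num, ...) ...
        "waiting_hook": wait_for_cpu_cache_registration,
    }

Note that _run_waiting_hook no longer wraps the call in try/except the way the old _pre_warmup_hook path did, so a failing hook propagates directly out of model init.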

lightllm/common/basemodel/triton_kernel/kv_cache_offload.py

Lines changed: 18 additions & 21 deletions
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 
 
 @triton.jit

@@ -72,16 +73,13 @@ def offload_gpu_kv_to_cpu(
 ):
     """
     this function is used to offload GPU KV cache to CPU KV cache.
-    Supports tensor parallelism (TP > 1).
     Args:
         token_indexes: (token_num,)
         gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
         cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim)
         page_indexes: (page_num,)
         page_readies: (page_num,)
     """
-    from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
-
     token_block_size = cpu_kv_cache.shape[2]
     token_num = page_indexes.shape[0] * token_block_size
     assert token_indexes.shape[0] >= token_num

@@ -92,9 +90,15 @@ def offload_gpu_kv_to_cpu(
 
     # Calculate head offset for tensor parallelism
    tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
     gpu_heads = gpu_kv_cache.shape[2]
     gpu_head_dim = gpu_kv_cache.shape[3]
-    cpu_head_offset = tp_rank * gpu_heads * gpu_head_dim
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
+    if tp_rank % factor != 0:
+        # redundant kv does not need to offload
+        return
 
     grid = (page_num,)
     num_warps = 4

@@ -142,7 +146,6 @@ def _load_cpu_cache_to_gpu(
     page_indexes_ptr,
     layer_num,
     head_all_dim,
-    all_move_token_num,
     cpu_head_offset,
     BLOCK_HEAD_ALL_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,

@@ -152,17 +155,11 @@
     if cpu_page_index == -1:
         return
 
-    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
-    padded_size = TOKEN_BLOCK * tl.num_programs(0) - all_move_token_num
-    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
-    token_range = token_range - padded_size
-
-    token_mask = token_range >= 0
+    token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
+    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     head_dim_mask = head_all_dim_range < head_all_dim
 
-    token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_mask, other=0).to(tl.int64)
-
     cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
     for layer_index in range(layer_num):
         cpu_ptr = (

@@ -176,14 +173,14 @@
 
         gpu_ptr = (
             gpu_kv_cache_ptr
-            + layer_index * gpu_stride0
+            + layer_index.to(tl.int64) * gpu_stride0
             + token_indexes[:, None] * gpu_stride1
             + head_all_dim_range[None, :]
         )
         tl.store(
             gpu_ptr,
             cpu_data,
-            mask=token_mask[:, None] & head_dim_mask[None, :],
+            mask=head_dim_mask[None, :],
         )
     return
 
@@ -196,27 +193,28 @@ def load_cpu_kv_to_gpu(
     page_indexes: torch.Tensor,
 ):
     """
-    this function is used to load CPU KV cache to GPU KV cache.
-    Supports tensor parallelism (TP > 1).
+    this function is used to offload GPU KV cache to CPU KV cache.
     Args:
         mem_indexes: (token_num,)
         gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
         cpu_kv_cache: (page_num, layer_num, token_block_size, head_num, head_dim)
         page_indexes: (page_num,)
     """
-    from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
-
     token_block_size = cpu_kv_cache.shape[2]
     token_num = page_indexes.shape[0] * token_block_size
     assert mem_indexes.shape[0] >= token_num
     page_num = page_indexes.shape[0]
+    assert len(mem_indexes) == page_num * token_block_size
     BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
     # Calculate head offset for tensor parallelism
     tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
     gpu_heads = gpu_kv_cache.shape[2]
     gpu_head_dim = gpu_kv_cache.shape[3]
-    cpu_head_offset = tp_rank * gpu_heads * gpu_head_dim
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
 
     grid = (page_num,)
     num_warps = 1

@@ -237,7 +235,6 @@ def load_cpu_kv_to_gpu(
         page_indexes_ptr=page_indexes,
         layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
-        all_move_token_num=len(mem_indexes),
         cpu_head_offset=cpu_head_offset,
         BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
         TOKEN_BLOCK=token_block_size,
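
The new offset arithmetic assumes the CPU cache stores one unsharded copy of the KV heads (cpu_kv_cache.shape[3]), while each TP rank holds gpu_heads of them, possibly replicated when tp_num * gpu_heads exceeds cpu_heads. A standalone sketch of the same calculation with made-up numbers shows why ranks whose KV slice is a duplicate skip the offload:

    def cpu_head_offset_for_rank(tp_rank, tp_num, gpu_heads, gpu_head_dim, cpu_heads):
        # How many TP ranks hold the same (replicated) slice of KV heads.
        factor = (tp_num * gpu_heads) // cpu_heads
        if tp_rank % factor != 0:
            return None  # redundant copy, nothing to offload
        return (tp_rank // factor) * gpu_heads * gpu_head_dim

    # Example: the CPU cache keeps 8 KV heads, TP=16, each GPU holds 1 KV head,
    # so every head lives on two ranks (factor == 2): even ranks offload, odd skip.
    for rank in range(4):
        print(rank, cpu_head_offset_for_rank(rank, tp_num=16, gpu_heads=1,
                                             gpu_head_dim=128, cpu_heads=8))
    # -> 0 0
    # -> 1 None
    # -> 2 128
    # -> 3 None

The load path computes the same offset but never returns early, since every rank must fill its own (possibly duplicated) heads from the CPU copy.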

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions
@@ -508,7 +508,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_cpu_cache",
         action="store_true",
-        help="""enable cpu cache to store kv cache.""",
+        help="""enable cpu cache to store kv cache. prefer to use hugepages for better performance.""",
     )
     parser.add_argument(
         "--cpu_cache_storage_size",

@@ -519,7 +519,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--cpu_cache_token_page_size",
         type=int,
-        default=256,
+        default=64,
         help="""The token page size of cpu cache""",
     )
     parser.add_argument("--enable_disk_cache", action="store_true", help="""enable disk cache to store kv cache.""")

lightllm/server/core/objs/atomic_array_lock.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import atomics
 from multiprocessing import shared_memory
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -26,6 +27,7 @@ def __init__(self, lock_name: str, lock_num: int):
     def _init_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.lock_name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.lock_name)
             logger.info(f"create lock shm {self.lock_name}")
         except:
             shm = shared_memory.SharedMemory(name=self.lock_name, create=False, size=self.dest_size)
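
register_posix_shm_for_cleanup comes from lightllm.utils.auto_shm_cleanup, which this commit does not show. The sketch below is only a guess at what such a registry could look like (unlink the named POSIX segments this process created when it exits), not the actual implementation:

    import atexit
    from multiprocessing import shared_memory

    _registered_shm_names = set()

    def register_posix_shm_for_cleanup(name: str) -> None:
        # Only the create=True branches call this, so only the creating process
        # ends up unlinking the segment.
        _registered_shm_names.add(name)

    def _cleanup_registered_shm() -> None:
        for name in _registered_shm_names:
            try:
                seg = shared_memory.SharedMemory(name=name, create=False)
                seg.close()
                seg.unlink()  # removes /dev/shm/<name> so restarts do not leak memory
            except FileNotFoundError:
                pass  # already gone

    atexit.register(_cleanup_registered_shm)

Whatever the real helper does, the pattern across the diffs below is consistent: registration happens only on the branch that successfully created the segment, never on the attach branch.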

lightllm/server/core/objs/atomic_lock.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import time
 from multiprocessing import shared_memory
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -25,6 +26,7 @@ def _init_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.lock_name, create=True, size=self.dest_size)
             logger.info(f"create lock shm {self.lock_name}")
+            register_posix_shm_for_cleanup(self.lock_name)
         except:
             shm = shared_memory.SharedMemory(name=self.lock_name, create=False, size=self.dest_size)
             logger.info(f"link lock shm {self.lock_name}")

lightllm/server/core/objs/rpc_shm.py

Lines changed: 4 additions & 0 deletions
@@ -5,6 +5,7 @@
 from typing import List
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -20,6 +21,7 @@ def __init__(self):
     def create_or_link_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_RPC_BYTE_SIZE)
+            register_posix_shm_for_cleanup(self.name)
         except:
             shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_RPC_BYTE_SIZE)
 
@@ -57,6 +59,7 @@ def __init__(self):
     def create_or_link_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+            register_posix_shm_for_cleanup(self.name)
         except:
             shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
 
@@ -99,6 +102,7 @@ def __init__(self, world_size):
     def create_or_link_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.name)
         except:
             shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)

lightllm/server/core/objs/shm_array.py

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 from multiprocessing import shared_memory
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)

@@ -18,6 +19,7 @@ def __init__(self, name, shape, dtype):
     def create_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.name)
         except:
             shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)
 
@@ -28,6 +30,7 @@ def create_shm(self):
         try:
             shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
             logger.info(f"create shm {self.name}")
+            register_posix_shm_for_cleanup(self.name)
         except:
             shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)
             logger.info(f"link shm {self.name}")

lightllm/server/core/objs/shm_req_manager.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 import numpy as np
 from lightllm.utils.envs_utils import get_unique_server_name
 from multiprocessing import shared_memory
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 from lightllm.utils.log_utils import init_logger
 from .req import Req, ChunkedPrefillReq, TokenHealingReq
 from .shm_array import ShmArray

@@ -53,6 +54,7 @@ def _init_reqs_shm(self):
         shm_name = f"{get_unique_server_name()}_req_shm_total"
         try:
             shm = shared_memory.SharedMemory(name=shm_name, create=True, size=self.req_shm_byte_size)
+            register_posix_shm_for_cleanup(shm_name)
             logger.info(f"create lock shm {shm_name}")
         except:
             shm = shared_memory.SharedMemory(name=shm_name, create=False, size=self.req_shm_byte_size)

lightllm/server/multi_level_kv_cache/cpu_cache_client.py

Lines changed: 2 additions & 4 deletions
@@ -12,7 +12,6 @@
     attach_shm_kv_cache_ptr,
     register_shm_ptr_to_pin,
 )
-from lightllm.utils.infer_utils import mark_start, mark_end
 
 logger = init_logger(__name__)
 
@@ -31,9 +30,9 @@ def __init__(self, init_shm_data: bool):
         self._create_cpu_status_list(init_shm_data)
         if init_shm_data:
             self._create_shm_cpu_kv_cache()
-            self.pin_reg_handle = None
+            self.attach_shm_handle = None
         else:
-            self.pin_reg_handle = self._attach_shm_cpu_kv_cache()
+            self.attach_shm_handle = self._attach_shm_cpu_kv_cache()
         return
 
     def get_one_empty_page(self, hash_key: int, disk_offload_enable: bool) -> Optional[int]:

@@ -215,7 +214,6 @@ def _create_shm_cpu_kv_cache(self):
 
     def _attach_shm_cpu_kv_cache(self):
         shm_ptr = attach_shm_kv_cache_ptr()
-        mark_start("blueswhen1")
         handle = register_shm_ptr_to_pin(shm_ptr=shm_ptr, size=self.kv_cache_tensor_meta.calcu_size())
         numpy_array = np.frombuffer(
             memoryview((ctypes.c_uint8 * self.kv_cache_tensor_meta.calcu_size()).from_address(shm_ptr)), dtype=np.uint8
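
The retained lines in _attach_shm_cpu_kv_cache wrap a raw shared-memory address in a zero-copy numpy view via ctypes. A self-contained sketch of that pattern, using an ordinary bytearray in place of the pointer returned by attach_shm_kv_cache_ptr():

    import ctypes
    import numpy as np

    size = 64
    backing = bytearray(size)  # stand-in for the pinned shared-memory region
    shm_ptr = ctypes.addressof((ctypes.c_uint8 * size).from_buffer(backing))

    # Same construction as in the diff: view the address as a ctypes byte array,
    # expose it through memoryview, and wrap it as uint8 without copying.
    numpy_array = np.frombuffer(
        memoryview((ctypes.c_uint8 * size).from_address(shm_ptr)), dtype=np.uint8
    )
    numpy_array[:8] = 255      # writes land directly in the underlying buffer
    assert backing[0] == 255

The rename from pin_reg_handle to attach_shm_handle keeps the handle returned by register_shm_ptr_to_pin on the client instance; the sketch above covers only the numpy-view part, not the pinning.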

lightllm/server/multi_level_kv_cache/manager.py

Lines changed: 0 additions & 9 deletions
@@ -31,15 +31,6 @@ def __init__(
         self.send_to_router = context.socket(zmq.PUSH)
         self.send_to_router.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}")
         logger.info(f"send_to_router sendhwm {self.send_to_router.getsockopt(zmq.SNDHWM)}")
-
-        # Automatically register shared-memory cleanup
-        try:
-            from lightllm.utils.auto_shm_cleanup import auto_register_cpu_cache
-
-            auto_register_cpu_cache()
-        except Exception as e:
-            logger.warning(f"Failed to register auto shm cleanup: {e}")
-
         self.cpu_cache_client = CpuKvCacheClient(init_shm_data=True)
         self.shm_req_manager = ShmReqManager()
         # Limits how many cpu cache matching operations can run concurrently.
