Commit b01f213

Author: wangzaijun
Commit message: rebase1
1 parent 0a9fad9 commit b01f213

File tree: 17 files changed, +539 additions, -156 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 9 additions & 0 deletions

@@ -61,6 +61,8 @@ def __init__(self, kvargs):
         self.finetune_config = kvargs.get("finetune_config", None)
         self.max_req_num = kvargs.get("max_req_num", 1000)
         self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
+        # Used to wait on external initialization (e.g. until CPU KV Cache registration has finished).
+        self.waiting_hook = kvargs.get("waiting_hook", None)
         # is_token_healing and return_all_prompt_logics are mutually exclusive modes; only one can be active at a time.
         # They mainly control how many tokens the prefill stage returns for subsequent processing.
         self.is_token_healing = kvargs.get("is_token_healing", False)
@@ -110,12 +112,19 @@ def __init__(self, kvargs):
         self._init_inferstate_cls()
         self._autotune_warmup()
         self._init_padded_req()
+        # The wait must happen before cuda graph init, to avoid capturing it by mistake.
+        self._run_waiting_hook()
         self._init_cudagraph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
         return
 
+    def _run_waiting_hook(self):
+        if self.waiting_hook is not None:
+            self.waiting_hook()
+        return
+
     def _init_config(self):
         with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
             self.config = json.load(json_file)
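
Note: the new waiting_hook kvarg lets the launcher stall model init (before cuda graph capture) until external setup, such as CPU KV cache registration, has finished. Below is a minimal sketch of how a caller might supply such a hook; the event object and its name are illustrative and not part of this commit.

    import threading

    # Hypothetical event, set by whichever process finishes registering the CPU KV cache.
    cpu_cache_ready = threading.Event()

    def waiting_hook():
        # Blocks inside the model's __init__ via _run_waiting_hook(),
        # which this commit places before _init_cudagraph().
        cpu_cache_ready.wait(timeout=600)

    kvargs = {
        # ... the usual model kvargs ...
        "waiting_hook": waiting_hook,
    }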

lightllm/common/basemodel/triton_kernel/kv_cache_offload.py

Lines changed: 36 additions & 19 deletions

@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 
 
 @triton.jit
@@ -22,6 +23,7 @@ def _offload_gpu_kv_to_cpu(
     page_readies_ptr,
     layer_num,
     head_all_dim,
+    cpu_head_offset,
     BLOCK_HEAD_ALL_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
@@ -38,12 +40,10 @@ def _offload_gpu_kv_to_cpu(
     token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
     head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
 
-    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
-
     for layer_index in range(layer_num):
         gpu_ptr = (
             gpu_kv_cache_ptr
-            + layer_index * gpu_stride0
+            + layer_index.to(tl.int64) * gpu_stride0
             + token_indexes[:, None] * gpu_stride1
             + head_all_dim_range[None, :]
         )
@@ -53,7 +53,7 @@ def _offload_gpu_kv_to_cpu(
             + cpu_page_index * cpu_stride0
             + layer_index * cpu_stride1
             + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + head_all_dim_range[None, :]
+            + (cpu_head_offset + head_all_dim_range[None, :])
         )
         tl.store(
             cpu_ptr,
@@ -88,6 +88,18 @@ def offload_gpu_kv_to_cpu(
     head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
     BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
+    # Calculate head offset for tensor parallelism
+    tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
+    gpu_heads = gpu_kv_cache.shape[2]
+    gpu_head_dim = gpu_kv_cache.shape[3]
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
+    if tp_rank % factor != 0:
+        # redundant kv does not need to offload
+        return
+
     grid = (page_num,)
     num_warps = 4
 
@@ -108,6 +120,7 @@ def offload_gpu_kv_to_cpu(
         page_readies_ptr=page_readies,
         layer_num=gpu_kv_cache.shape[0],
        head_all_dim=head_all_dim,
+        cpu_head_offset=cpu_head_offset,
         BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
         TOKEN_BLOCK=token_block_size,
         num_warps=num_warps,
@@ -133,7 +146,7 @@ def _load_cpu_cache_to_gpu(
     page_indexes_ptr,
     layer_num,
     head_all_dim,
-    all_move_token_num,
+    cpu_head_offset,
     BLOCK_HEAD_ALL_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
@@ -142,38 +155,32 @@ def _load_cpu_cache_to_gpu(
     if cpu_page_index == -1:
         return
 
-    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
-    padded_size = TOKEN_BLOCK * tl.num_programs(0) - all_move_token_num
-    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
-    token_range = token_range - padded_size
-
-    token_mask = token_range >= 0
+    token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
+    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     head_dim_mask = head_all_dim_range < head_all_dim
 
-    token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_mask, other=0).to(tl.int64)
-
-    cpu_page_index = tl.load(page_indexes_ptr + block_index)
+    cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
     for layer_index in range(layer_num):
         cpu_ptr = (
             cpu_kv_cache_ptr
             + cpu_page_index * cpu_stride0
             + layer_index * cpu_stride1
             + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + head_all_dim_range[None, :]
+            + (cpu_head_offset + head_all_dim_range[None, :])
         )
         cpu_data = tl.load(cpu_ptr, mask=head_dim_mask[None, :], other=0.0)
 
         gpu_ptr = (
             gpu_kv_cache_ptr
-            + layer_index * gpu_stride0
+            + layer_index.to(tl.int64) * gpu_stride0
             + token_indexes[:, None] * gpu_stride1
             + head_all_dim_range[None, :]
         )
         tl.store(
             gpu_ptr,
             cpu_data,
-            mask=token_mask[:, None] & head_dim_mask[None, :],
+            mask=head_dim_mask[None, :],
         )
     return
 
@@ -197,12 +204,22 @@ def load_cpu_kv_to_gpu(
     token_num = page_indexes.shape[0] * token_block_size
     assert mem_indexes.shape[0] >= token_num
     page_num = page_indexes.shape[0]
+    assert len(mem_indexes) == page_num * token_block_size
     BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
+    # Calculate head offset for tensor parallelism
+    tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
+    gpu_heads = gpu_kv_cache.shape[2]
+    gpu_head_dim = gpu_kv_cache.shape[3]
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
+
     grid = (page_num,)
     num_warps = 1
 
-    _offload_gpu_kv_to_cpu[grid](
+    _load_cpu_cache_to_gpu[grid](
         token_indexes_ptr=mem_indexes,
         gpu_kv_cache_ptr=gpu_kv_cache,
         gpu_stride0=gpu_kv_cache.stride(0),
@@ -218,7 +235,7 @@ def load_cpu_kv_to_gpu(
         page_indexes_ptr=page_indexes,
         layer_num=gpu_kv_cache.shape[0],
        head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
-        all_move_token_num=len(mem_indexes),
+        cpu_head_offset=cpu_head_offset,
         BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
         TOKEN_BLOCK=token_block_size,
         num_warps=num_warps,
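
Note: the new cpu_head_offset maps each TP rank's local KV heads into the CPU page's full-head layout, and ranks that only hold duplicated (GQA-replicated) KV heads skip the offload entirely. A small worked example of the arithmetic; the shapes below are illustrative and not taken from this commit.

    # Illustrative only: 8-way TP, 4 total KV heads, head_dim 128, so each rank
    # holds gpu_heads = 1 local KV head and pairs of ranks hold duplicate copies.
    tp_num, gpu_heads, gpu_head_dim, cpu_heads = 8, 1, 128, 4

    factor = (tp_num * gpu_heads) // cpu_heads  # 2 ranks share each CPU head slot
    for tp_rank in range(tp_num):
        cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
        skip = tp_rank % factor != 0  # duplicated kv does not need to offload
        print(tp_rank, cpu_head_offset, "skip" if skip else "offload")
    # ranks 0, 2, 4, 6 offload at offsets 0, 128, 256, 384; ranks 1, 3, 5, 7 are skipped.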

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions

@@ -522,7 +522,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_cpu_cache",
         action="store_true",
-        help="""enable cpu cache to store kv cache.""",
+        help="""enable cpu cache to store kv cache. prefer to use hugepages for better performance.""",
     )
     parser.add_argument(
         "--cpu_cache_storage_size",
@@ -533,7 +533,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--cpu_cache_token_page_size",
         type=int,
-        default=256,
+        default=64,
         help="""The token page size of cpu cache""",
     )
     parser.add_argument("--enable_disk_cache", action="store_true", help="""enable disk cache to store kv cache.""")

lightllm/server/core/objs/atomic_array_lock.py

Lines changed: 12 additions & 0 deletions

@@ -3,6 +3,7 @@
 from multiprocessing import shared_memory
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.shm_utils import create_or_link_shm
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -18,6 +19,17 @@ def __init__(self, lock_name: str, lock_num: int):
             self.shm.buf.cast("i")[index] = 0
         return
 
+    def _init_shm(self):
+        try:
+            shm = shared_memory.SharedMemory(name=self.lock_name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.lock_name)
+            logger.info(f"create lock shm {self.lock_name}")
+        except:
+            shm = shared_memory.SharedMemory(name=self.lock_name, create=False, size=self.dest_size)
+            logger.info(f"link lock shm {self.lock_name}")
+        self.shm = shm
+        return
+
     def get_lock_context(self, lock_index: int) -> "AtomicLockItem":
         assert lock_index < self.lock_num
         return AtomicLockItem(self, lock_index)
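
Note: this and the following shm-backed classes now call register_posix_shm_for_cleanup right after creating a segment. The commit does not include that utility, so the sketch below is only a guess at what such a registry could look like, using atexit; the real lightllm.utils.auto_shm_cleanup module may work differently.

    import atexit
    from multiprocessing import shared_memory

    _registered_names = set()

    def register_posix_shm_for_cleanup(name):
        # Remember segments this process created so they can be unlinked on exit.
        _registered_names.add(name)

    @atexit.register
    def _cleanup_registered_shm():
        for name in _registered_names:
            try:
                shm = shared_memory.SharedMemory(name=name, create=False)
                shm.close()
                shm.unlink()  # removes /dev/shm/<name> once all handles are gone
            except FileNotFoundError:
                pass  # already removed elsewhere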

lightllm/server/core/objs/atomic_lock.py

Lines changed: 12 additions & 0 deletions

@@ -3,6 +3,7 @@
 from multiprocessing import shared_memory
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.shm_utils import create_or_link_shm
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -16,6 +17,17 @@ def __init__(self, lock_name: str):
         self.shm.buf.cast("i")[0] = 0
         return
 
+    def _init_shm(self):
+        try:
+            shm = shared_memory.SharedMemory(name=self.lock_name, create=True, size=self.dest_size)
+            logger.info(f"create lock shm {self.lock_name}")
+            register_posix_shm_for_cleanup(self.lock_name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.lock_name, create=False, size=self.dest_size)
+            logger.info(f"link lock shm {self.lock_name}")
+        self.shm = shm
+        return
+
     def __enter__(self):
         with atomics.atomicview(buffer=self.shm.buf, atype=atomics.INT) as a:
             while not a.cmpxchg_weak(0, 1):

lightllm/server/core/objs/rpc_shm.py

Lines changed: 31 additions & 2 deletions

@@ -6,6 +6,7 @@
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.shm_utils import create_or_link_shm
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 
 logger = init_logger(__name__)
 
@@ -20,8 +21,13 @@ def __init__(self):
 
     def create_or_link_shm(self):
         self.shm = create_or_link_shm(self.name, LIGHTLLM_RPC_BYTE_SIZE)
+        try:
+            shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_RPC_BYTE_SIZE)
+            register_posix_shm_for_cleanup(self.name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_RPC_BYTE_SIZE)
 
-        return
+        return shm
 
     def write_func_params(self, func_name, args):
         objs_bytes = pickle.dumps((func_name, args))
@@ -42,6 +48,24 @@ def __init__(self):
 
     def create_or_link_shm(self):
         self.shm = create_or_link_shm(self.name, LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+        try:
+            shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+            register_posix_shm_for_cleanup(self.name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+
+        if shm.size != LIGHTLLM_RPC_RESULT_BYTE_SIZE:
+            logger.warning(f"size not same, unlink shm {self.name} and create again")
+            shm.close()
+            shm.unlink()
+            try:
+                shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+                logger.info(f"create shm {self.name}")
+            except:
+                shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_RPC_RESULT_BYTE_SIZE)
+                logger.info(f"link shm {self.name}")
+
+        self.shm = shm
         return
 
     def write_func_result(self, func_name, ret):
@@ -68,12 +92,17 @@ def __init__(self, world_size):
 
     def create_or_link_shm(self):
         self.shm = create_or_link_shm(self.name, self.dest_size)
+        try:
+            shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)
 
         self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
         self.arr[:] = 0
         self.arr0 = self.arr[0 : self.world_size]
         self.arr1 = self.arr[self.world_size : 2 * self.world_size]
-        return
+        return shm
 
     def add_mark(self, tp_rank: int):
         self.arr0[tp_rank] += 1

lightllm/server/core/objs/shm_array.py

Lines changed: 20 additions & 0 deletions

@@ -1,5 +1,6 @@
 import numpy as np
 from multiprocessing import shared_memory
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.shm_utils import create_or_link_shm
 
@@ -18,6 +19,25 @@ def __init__(self, name, shape, dtype):
 
     def create_shm(self):
         self.shm = create_or_link_shm(self.name, self.dest_size)
+        try:
+            shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
+            register_posix_shm_for_cleanup(self.name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)
+
+        if shm.size != self.dest_size:
+            logger.warning(f"size not same, unlink shm {self.name} and create again")
+            shm.close()
+            shm.unlink()
+            try:
+                shm = shared_memory.SharedMemory(name=self.name, create=True, size=self.dest_size)
+                logger.info(f"create shm {self.name}")
+                register_posix_shm_for_cleanup(self.name)
+            except:
+                shm = shared_memory.SharedMemory(name=self.name, create=False, size=self.dest_size)
+                logger.info(f"link shm {self.name}")
+
+        self.shm = shm  # The SharedMemory object must stay referenced, otherwise the buffer is released.
         self.arr = np.ndarray(self.shape, dtype=self.dtype, buffer=self.shm.buf)
 
     def link_shm(self):
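
Note: a hedged usage sketch for ShmArray, based only on what this diff shows (the (name, shape, dtype) constructor, create_shm, link_shm and the .arr ndarray view); the segment name and process roles are illustrative.

    import numpy as np
    from lightllm.server.core.objs.shm_array import ShmArray

    # Writer process: creates (or attaches to) the backing POSIX shm segment.
    writer = ShmArray("demo_shm_array", shape=(16,), dtype=np.int64)
    writer.create_shm()
    writer.arr[:] = np.arange(16)

    # Reader process (run separately): attaches to the same segment by name.
    # reader = ShmArray("demo_shm_array", shape=(16,), dtype=np.int64)
    # reader.link_shm()
    # print(reader.arr[:5])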

lightllm/server/core/objs/shm_objs_io_buffer.py

Lines changed: 23 additions & 0 deletions

@@ -5,6 +5,8 @@
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.shm_utils import create_or_link_shm
+from lightllm.utils.auto_shm_cleanup import register_posix_shm_for_cleanup
+from multiprocessing import shared_memory
 
 LIGHTLLM_REQS_BUFFER_BYTE_SIZE = int(os.getenv("LIGHTLLM_REQS_BUFFER_BYTE_SIZE", 64 * 1024 * 1024))  # 64 MB buffer by default
 
@@ -53,6 +55,27 @@ def read_obj(self):
 
     def _create_or_link_shm(self):
         self.shm = create_or_link_shm(self.name, LIGHTLLM_REQS_BUFFER_BYTE_SIZE)
+        try:
+            shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_REQS_BUFFER_BYTE_SIZE)
+            logger.info(f"create shm {self.name}")
+            register_posix_shm_for_cleanup(self.name)
+        except:
+            shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_REQS_BUFFER_BYTE_SIZE)
+            logger.info(f"link shm {self.name}")
+
+        if shm.size != LIGHTLLM_REQS_BUFFER_BYTE_SIZE:
+            logger.warning(f"size not same, unlink shm {self.name} and create again")
+            shm.close()
+            shm.unlink()
+            try:
+                shm = shared_memory.SharedMemory(name=self.name, create=True, size=LIGHTLLM_REQS_BUFFER_BYTE_SIZE)
+                logger.info(f"create shm {self.name}")
+                register_posix_shm_for_cleanup(self.name)
+            except:
+                shm = shared_memory.SharedMemory(name=self.name, create=False, size=LIGHTLLM_REQS_BUFFER_BYTE_SIZE)
+                logger.info(f"link shm {self.name}")
+
+        self.shm = shm
         self.int_view = self.shm.buf.cast("i")
         # The first 4 bytes are a special counter: the router writes it, and each inference process
         # decrements it after taking all the data, until it returns to 0.
         self.int_view[0] = 0