
Commit dcf9c2d

[Feature] Optimize prefix cache (#3208)
* [LLM] support ep
* Update worker_process.py
* Update expert_service.py
* Update worker_process.py
* format files
* optimize prefix cache
* optimize prefix cache
* optimize prefix cache
* pre commit format
* pre commit format
* pre commit format
* Update cache_messager.py
1 parent 9f99718 commit dcf9c2d

File tree

7 files changed: +315 −148 lines changed


fastdeploy/cache_manager/cache_messager.py

Lines changed: 160 additions & 23 deletions
@@ -14,18 +14,72 @@
 # limitations under the License.
 """
 
+import argparse
+import json
 import math
-import threading
 import time
-
+import threading
 import numpy as np
 import paddle
 
 from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+from fastdeploy.config import SpeculativeConfig
 from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+from fastdeploy.model_executor.ops.gpu import set_data_ipc
 from fastdeploy.utils import get_logger
 
-logger = get_logger("cache_messager", "cache_messager.log")
+
+def parse_args():
+    """
+    Parse arguments from the command line.
+    """
+    parser = argparse.ArgumentParser("Cache Messager")
+    parser.add_argument(
+        "--splitwise_role",
+        type=str,
+        default="mixed",
+        help="splitwise role, can be decode, prefill or mixed",
+    )
+    parser.add_argument("--rank", type=int, default=0, help="current rank")
+    parser.add_argument("--device_id", type=int, default=0, help="device id")
+    parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
+    parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+    parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+    parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
+    parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel")
+    parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+    parser.add_argument(
+        "--protocol",
+        type=str,
+        default="ipc",
+        help="cache transfer protocol, only support ipc now",
+    )
+    parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+    parser.add_argument(
+        "--engine_worker_queue_port",
+        type=int,
+        default=9923,
+        help="engine worker queue port",
+    )
+    parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+    parser.add_argument("--block_size", type=int, default=64, help="cache block size (tokens)")
+    parser.add_argument(
+        "--cache_dtype",
+        type=str,
+        default="bfloat16",
+        choices=["uint8", "bfloat16"],
+        help="cache dtype",
+    )
+    parser.add_argument(
+        "--speculative_config",
+        type=json.loads,
+        default="{}",
+        help="speculative config",
+    )
+    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+    args = parser.parse_args()
+    return args
 
 
 class CacheMessager:
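Note on the new --speculative_config flag: it uses json.loads as its argparse type, so a JSON object can be passed directly on the command line and arrives as a dict. A minimal standalone demonstration; the payload shown is illustrative, not taken from this commit:

import argparse
import json

# Mirror the flag definition from parse_args() above.
parser = argparse.ArgumentParser("Cache Messager")
parser.add_argument("--speculative_config", type=json.loads, default="{}")

# argparse applies json.loads to the string, yielding a Python dict.
args = parser.parse_args(["--speculative_config", '{"enabled": true}'])
print(args.speculative_config)        # {'enabled': True}
print(type(args.speculative_config))  # <class 'dict'>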
@@ -43,7 +97,7 @@ def __init__(
         gpu_cache_kvs,
         rank,
         nranks,
-        num_layers,
+        num_hidden_layers,
         gpu_id=0,
         rdma_port=None,
     ):
@@ -57,7 +111,7 @@ def __init__(
             gpu_cache_kvs (dict): GPU kv cache
             rank (int): current rank
             nranks (int): global rank number
-            num_layers (int): model layer number
+            num_hidden_layers (int): model layer number
             gpu_id (int, optional): GPU ID
             rdma_port (int, optional): RDMA port
 
@@ -86,13 +140,13 @@ def __init__(
         logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")
 
         # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
-        self.num_layers = num_layers
+        self.num_hidden_layers = num_hidden_layers
         cache_k_ptr_list = []
         cache_v_ptr_list = []
         cache_k = []
         cache_v = []
         self.messager = {}
-        for layer_idx in range(self.num_layers):
+        for layer_idx in range(self.num_hidden_layers):
             key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
             cache_k.append(key_cache)
@@ -109,7 +163,7 @@ def __init__(
         if key_cache.dtype == paddle.bfloat16:
             block_bytes *= 2
         logger.info(
-            f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
+            f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
             f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
         )
         self.block_bytes = block_bytes
@@ -144,17 +198,13 @@ def __init__(
         self.cache_info = dict()
         self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)
 
-        layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-        layerwise_send_cache_thread.daemon = True
-        layerwise_send_cache_thread.start()
-
         connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
         connect_rdma_thread.daemon = True
         connect_rdma_thread.start()
 
         logger.info(f"cache messager init finished, use {transfer_protocol}")
 
-    def _prefill_layerwise_send_cache_thread(self):
+    def prefill_layerwise_send_cache_thread(self):
         """
         layerwise_send_cache_thread:
             send cache to other instance
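Note: __init__ no longer spawns the layerwise-send daemon thread; the loop is now the public method prefill_layerwise_send_cache_thread, which the new standalone entry point calls directly (see main() below). A minimal sketch of how an embedding caller could still run it in the background; start_cache_sender is a hypothetical helper, not part of this commit:

import threading

def start_cache_sender(messager):
    """Run the now-public send loop on a daemon thread.

    `messager` is assumed to be an already-constructed CacheMessager;
    this mirrors the thread setup the old __init__ used to do itself.
    """
    sender = threading.Thread(
        target=messager.prefill_layerwise_send_cache_thread,
        daemon=True,
    )
    sender.start()
    return sender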
@@ -204,7 +254,7 @@ def _prefill_layerwise_send_cache_thread(self):
                 cache_info = self.engine_worker_queue.get_cache_info()
 
                 if cache_info:
-                    logger.debug(f"cache info {cache_info}")
+                    logger.info(f"cache info {cache_info}")
                     for info in cache_info:
                         if info["request_id"] in self.cache_info:
                             self.cache_info[info["request_id"]].update(info)
@@ -223,7 +273,7 @@ def _prefill_layerwise_send_cache_thread(self):
                             self.cache_info[info["request_id"]] = info
                 prefilled_layer_idx = layer_shm_value.value[0]
                 prefilled_step_idx = step_shm_value.value[0]
-                if prefilled_layer_idx == self.num_layers - 1:
+                if prefilled_layer_idx == self.num_hidden_layers - 1:
                     time.sleep(0.001)
                     prefilled_layer_idx = layer_shm_value.value[0]
                     prefilled_step_idx = step_shm_value.value[0]
@@ -234,7 +284,7 @@ def _prefill_layerwise_send_cache_thread(self):
                 if not self.cache_info:
                     time.sleep(0.001)
                     continue
-                logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                 for req_id, item in list(self.cache_info.items()):
                     if "status" not in item:
                         continue
@@ -251,7 +301,7 @@ def _prefill_layerwise_send_cache_thread(self):
                         target_id = int(item["rdma_ports"][self.rank])
                         status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                         if not status:
-                            logger.error(f"connect to {target_ip}:{target_id} failed")
+                            logger.info(f"connect to {target_ip}:{target_id} failed")
                             item["status"] = "error"
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
@@ -263,7 +313,7 @@ def _prefill_layerwise_send_cache_thread(self):
                     src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                     dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                     if item["current_id"] < prefilled_step_idx:
-                        current_layer_idx = self.num_layers
+                        current_layer_idx = self.num_hidden_layers
                     else:
                         current_layer_idx = prefilled_layer_idx + 1
 
@@ -281,7 +331,7 @@ def _prefill_layerwise_send_cache_thread(self):
                             self.engine_worker_queue.finish_request_barrier.wait()
                             if self.rank == 0:
                                 self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
-                            logger.error(
+                            logger.info(
                                 f"write cache failed, layer_idx: {layer_idx}, "
                                 f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                             )
@@ -292,14 +342,14 @@ def _prefill_layerwise_send_cache_thread(self):
                         block_num = len(src_block_ids)
                         avg_time_per_block = cost_time * 1000 / block_num  # ms
                         send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
-                        logger.debug(
+                        logger.info(
                             f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
                             f" {current_transfer_protocol}"
                             f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                             f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                         )
                     item["layer_idx"] = current_layer_idx
-                    if item["layer_idx"] == self.num_layers:
+                    if item["layer_idx"] == self.num_hidden_layers:
                         if item["transfer_protocol"] == "ipc":
                             self.messager["ipc"].write_block_by_sync(target_id)
                         logger.info(f"finish write cache {item['request_id']}")
@@ -313,8 +363,8 @@ def _prefill_layerwise_send_cache_thread(self):
                     self.last_layer_idx = prefilled_layer_idx
 
             except Exception as e:
-                logger.error(f"prefill layerwise send cache thread has exception: {e}")
-
+                logger.info(f"prefill layerwise send cache thread has exception: {e}")
+
     def _handle_connect_task(self):
         while True:
             try:
@@ -333,3 +383,90 @@ def _handle_connect_task(self):
                 self.engine_worker_queue.put_connect_rdma_task_response(response)
             except Exception as e:
                 logger.error(f"handle_connect_task has exception: {e}")
+
+
+def main():
+    device = args.device_id
+    rank = args.rank
+    paddle.set_device(f"gpu:{device}")
+    cache_type = args.cache_dtype
+    speculative_config = SpeculativeConfig(args.speculative_config)
+    num_extra_layers = speculative_config.num_extra_cache_layer
+    num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
+    gpu_cache_kvs = {}
+    gpu_cache_k_tensors = []
+    gpu_cache_v_tensors = []
+
+    for i in range(args.num_hidden_layers + num_extra_layers):
+        num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
+
+        gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
+        gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+            shape=[
+                num_gpu_blocks,
+                args.kv_num_head,
+                args.block_size,
+                args.head_dim,
+            ],
+            fill_value=0,
+            dtype=cache_type,
+        )
+        gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
+
+        set_data_ipc(
+            gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
+            f"key_caches_{i}_rank{rank}.device{device}",
+        )
+        set_data_ipc(
+            gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
+            f"value_caches_{i}_rank{rank}.device{device}",
+        )
+    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
+    logger.info(f"device :{device}")
+    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
+    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+
+    cache_messager = CacheMessager(
+        splitwise_role=args.splitwise_role,
+        transfer_protocol=args.protocol,
+        pod_ip=args.pod_ip,
+        engine_worker_queue_port=args.engine_worker_queue_port,
+        local_data_parallel_id=args.local_data_parallel_id,
+        gpu_cache_kvs=gpu_cache_kvs,
+        rank=rank,
+        nranks=args.mp_num,
+        num_hidden_layers=args.num_hidden_layers + num_extra_layers,
+        gpu_id=device,
+        rdma_port=args.rdma_port,
+    )
+
+    cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+    cache_ready_signal = IPCSignal(
+        name="cache_ready_signal",
+        array=cache_ready_signal_data,
+        dtype=np.int32,
+        suffix=args.engine_pid,
+        create=False,
+    )
+    cache_ready_signal.value[rank] = 1
+    cache_messager.prefill_layerwise_send_cache_thread()
+
+
+if __name__ == "__main__":
+
+    args = parse_args()
+    logger = get_logger("cache_messager", "cache_messager.log")
+
+    logger.info("create cache messager...")
+    logger.info(f"{args}")
+    main()
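With this __main__ block, the cache messager can be launched as a standalone per-rank process. A hypothetical launcher sketch: only the flag names come from parse_args() above; the flag values, the script path, and the use of subprocess are illustrative assumptions:

import subprocess
import sys

# Illustrative launch of the standalone cache messager for one rank.
cmd = [
    sys.executable,
    "fastdeploy/cache_manager/cache_messager.py",
    "--splitwise_role", "prefill",
    "--rank", "0",
    "--device_id", "0",
    "--num_hidden_layers", "32",
    "--head_dim", "128",
    "--kv_num_head", "8",
    "--mp_num", "1",
    "--protocol", "ipc",
    "--engine_worker_queue_port", "9923",
    "--num_gpu_blocks", "1024",
    "--block_size", "64",
    "--cache_dtype", "bfloat16",
    "--speculative_config", "{}",
    "--engine_pid", "12345",
]
proc = subprocess.Popen(cmd)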
