# limitations under the License.
"""

+ import argparse
+ import json
import math
- import threading
import time
-
+ import threading

import numpy as np
import paddle

from fastdeploy.cache_manager.transfer_factory import IPCCommManager, RDMACommManager
+ from fastdeploy.config import SpeculativeConfig
from fastdeploy.inter_communicator import EngineWorkerQueue, IPCSignal
+ from fastdeploy.model_executor.ops.gpu import set_data_ipc
from fastdeploy.utils import get_logger

- logger = get_logger("cache_messager", "cache_messager.log")
+
+ def parse_args():
+     """
+     Parse arguments from the command line.
+     """
+     parser = argparse.ArgumentParser("Cache Messager")
+     parser.add_argument(
+         "--splitwise_role",
+         type=str,
+         default="mixed",
+         help="splitwise role, can be decode, prefill or mixed",
+     )
+     parser.add_argument("--rank", type=int, default=0, help="current rank")
+     parser.add_argument("--device_id", type=int, default=0, help="device id")
+     parser.add_argument("--num_hidden_layers", type=int, default=1, help="model num layers")
+     parser.add_argument("--head_dim", type=int, default=1, help="model head dim")
+     parser.add_argument("--kv_num_head", type=int, default=1, help="model kv num head")
+     parser.add_argument("--rdma_port", type=str, default="", help="rdma port")
+     parser.add_argument("--mp_num", type=int, default=1, help="number of model parallel ranks")
+     parser.add_argument("--engine_pid", type=str, default=None, help="engine pid")
+     parser.add_argument(
+         "--protocol",
+         type=str,
+         default="ipc",
+         help="cache transfer protocol, only support ipc now",
+     )
+     parser.add_argument("--pod_ip", type=str, default="0.0.0.0", help="pod ip")
+     parser.add_argument(
+         "--engine_worker_queue_port",
+         type=int,
+         default=9923,
+         help="engine worker queue port",
+     )
+     parser.add_argument("--num_gpu_blocks", type=int, default=1, help="gpu cache block number")
+     parser.add_argument("--block_size", type=int, default=64, help="cache block size (tokens)")
+     parser.add_argument(
+         "--cache_dtype",
+         type=str,
+         default="bfloat16",
+         choices=["uint8", "bfloat16"],
+         help="cache dtype",
+     )
+     parser.add_argument(
+         "--speculative_config",
+         type=json.loads,
+         default="{}",
+         help="speculative config",
+     )
+     parser.add_argument("--local_data_parallel_id", type=int, default=0)
+
+     args = parser.parse_args()
+     return args


class CacheMessager:
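As a quick orientation for the new standalone entry point, the snippet below sketches how the parse_args() helper added above would consume a typical command line. It assumes parse_args() is in scope; the flag values are hypothetical examples, not values taken from this commit.

    import sys

    # Hypothetical invocation of the cache messager's argument parser.
    sys.argv = [
        "cache_messager.py",
        "--splitwise_role", "prefill",
        "--rank", "0",
        "--device_id", "3",
        "--num_hidden_layers", "32",
        "--kv_num_head", "8",
        "--head_dim", "128",
        "--mp_num", "2",
        "--cache_dtype", "bfloat16",
        "--speculative_config", "{}",
    ]
    args = parse_args()
    assert args.splitwise_role == "prefill"
    assert args.speculative_config == {}  # the flag is parsed by json.loads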
@@ -43,7 +97,7 @@ def __init__(
        gpu_cache_kvs,
        rank,
        nranks,
-         num_layers,
+         num_hidden_layers,
        gpu_id=0,
        rdma_port=None,
    ):
@@ -57,7 +111,7 @@ def __init__(
            gpu_cache_kvs (dict): GPU kv cache
            rank (int): current rank
            nranks (int): global rank number
-             num_layers (int): model layer number
+             num_hidden_layers (int): model layer number
            gpu_id (int, optional): GPU ID
            rdma_port (int, optional): RDMA port
@@ -86,13 +140,13 @@ def __init__(
        logger.info(f"splitwise role: {splitwise_role}, {transfer_protocol}" f"rank: {rank}")

        # 1. initialize the cache_k_ptr_list and cache_v_ptr_list
-         self.num_layers = num_layers
+         self.num_hidden_layers = num_hidden_layers
        cache_k_ptr_list = []
        cache_v_ptr_list = []
        cache_k = []
        cache_v = []
        self.messager = {}
-         for layer_idx in range(self.num_layers):
+         for layer_idx in range(self.num_hidden_layers):
            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
            cache_k.append(key_cache)
@@ -109,7 +163,7 @@ def __init__(
        if key_cache.dtype == paddle.bfloat16:
            block_bytes *= 2
        logger.info(
-             f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
+             f"layers {num_hidden_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
            f"block_bytes: {block_bytes}, dtype: {key_cache.dtype}"
        )
        self.block_bytes = block_bytes
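For orientation, here is a small sketch of the byte accounting behind the block_bytes figure this hunk logs. It assumes one cache block covers kv_num_head * block_size * head_dim elements (the per-layer shape allocated in main() below) at one byte per element for uint8; the bfloat16 branch then doubles the count, mirroring the check above. The helper name and the example values are illustrative only.

    def estimate_block_bytes(kv_num_head, block_size, head_dim, dtype="bfloat16"):
        # One cache block holds kv_num_head * block_size * head_dim elements.
        block_bytes = kv_num_head * block_size * head_dim
        if dtype == "bfloat16":  # two bytes per element, as in the dtype check above
            block_bytes *= 2
        return block_bytes

    # e.g. 8 kv heads, 64-token blocks, head_dim 128 -> 131072 bytes (~128 KiB) per block
    print(estimate_block_bytes(8, 64, 128))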
@@ -144,17 +198,13 @@ def __init__(
        self.cache_info = dict()
        self.rank_id = self.rank + local_data_parallel_id * self.nranks  # align with engine worker rank (paddle.distributed.launch)

-         layerwise_send_cache_thread = threading.Thread(target=self._prefill_layerwise_send_cache_thread)
-         layerwise_send_cache_thread.daemon = True
-         layerwise_send_cache_thread.start()
-
        connect_rdma_thread = threading.Thread(target=self._handle_connect_task)
        connect_rdma_thread.daemon = True
        connect_rdma_thread.start()

        logger.info(f"cache messager init finished, use {transfer_protocol}")

-     def _prefill_layerwise_send_cache_thread(self):
+     def prefill_layerwise_send_cache_thread(self):
        """
        layerwise_send_cache_thread:
        send cache to other instance
@@ -204,7 +254,7 @@ def _prefill_layerwise_send_cache_thread(self):
                cache_info = self.engine_worker_queue.get_cache_info()

                if cache_info:
-                     logger.debug(f"cache info {cache_info}")
+                     logger.info(f"cache info {cache_info}")
                    for info in cache_info:
                        if info["request_id"] in self.cache_info:
                            self.cache_info[info["request_id"]].update(info)
@@ -223,7 +273,7 @@ def _prefill_layerwise_send_cache_thread(self):
                            self.cache_info[info["request_id"]] = info
                prefilled_layer_idx = layer_shm_value.value[0]
                prefilled_step_idx = step_shm_value.value[0]
-                 if prefilled_layer_idx == self.num_layers - 1:
+                 if prefilled_layer_idx == self.num_hidden_layers - 1:
                    time.sleep(0.001)
                    prefilled_layer_idx = layer_shm_value.value[0]
                    prefilled_step_idx = step_shm_value.value[0]
@@ -234,7 +284,7 @@ def _prefill_layerwise_send_cache_thread(self):
                if not self.cache_info:
                    time.sleep(0.001)
                    continue
-                 logger.debug(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
+                 logger.info(f"prefilled_layer_idx: {prefilled_layer_idx}, prefilled_step_idx: {prefilled_step_idx}")
                for req_id, item in list(self.cache_info.items()):
                    if "status" not in item:
                        continue
@@ -251,7 +301,7 @@ def _prefill_layerwise_send_cache_thread(self):
                        target_id = int(item["rdma_ports"][self.rank])
                        status = self.messager[current_transfer_protocol].connect(target_ip, target_id)
                        if not status:
-                             logger.error(f"connect to {target_ip}:{target_id} failed")
+                             logger.info(f"connect to {target_ip}:{target_id} failed")
                            item["status"] = "error"
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
@@ -263,7 +313,7 @@ def _prefill_layerwise_send_cache_thread(self):
                    src_block_ids = paddle.to_tensor(item["src_block_ids"], dtype="int32", place="cpu")
                    dest_block_ids = paddle.to_tensor(item["dest_block_ids"], dtype="int32", place="cpu")
                    if item["current_id"] < prefilled_step_idx:
-                         current_layer_idx = self.num_layers
+                         current_layer_idx = self.num_hidden_layers
                    else:
                        current_layer_idx = prefilled_layer_idx + 1
@@ -281,7 +331,7 @@ def _prefill_layerwise_send_cache_thread(self):
                            self.engine_worker_queue.finish_request_barrier.wait()
                            if self.rank == 0:
                                self.engine_worker_queue.put_finished_req([(item["request_id"], "write cache error")])
-                             logger.error(
+                             logger.info(
                                f"write cache failed, layer_idx: {layer_idx}, "
                                f"req_id: {item['request_id']}, dest_ip: {target_ip}"
                            )
@@ -292,14 +342,14 @@ def _prefill_layerwise_send_cache_thread(self):
                        block_num = len(src_block_ids)
                        avg_time_per_block = cost_time * 1000 / block_num  # ms
                        send_cache_speed = block_num * self.block_bytes / 1073741824 / cost_time  # GB/s
-                         logger.debug(
+                         logger.info(
                            f"finish write cache for a layer, {item['request_id']}, {layer_idx}"
                            f" {current_transfer_protocol}"
                            f"block_num: {block_num}, send_cache_speed(GB/s): {round(send_cache_speed, 5)},"
                            f"avg_time per block(ms): {round(avg_time_per_block, 5)}"
                        )
                    item["layer_idx"] = current_layer_idx
-                     if item["layer_idx"] == self.num_layers:
+                     if item["layer_idx"] == self.num_hidden_layers:
                        if item["transfer_protocol"] == "ipc":
                            self.messager["ipc"].write_block_by_sync(target_id)
                        logger.info(f"finish write cache {item['request_id']}")
@@ -313,8 +363,8 @@ def _prefill_layerwise_send_cache_thread(self):
                self.last_layer_idx = prefilled_layer_idx

            except Exception as e:
-                 logger.error(f"prefill layerwise send cache thread has exception: {e}")
-
+                 logger.info(f"prefill layerwise send cache thread has exception: {e}")
+
    def _handle_connect_task(self):
        while True:
            try:
@@ -333,3 +383,90 @@ def _handle_connect_task(self):
                self.engine_worker_queue.put_connect_rdma_task_response(response)
            except Exception as e:
                logger.error(f"handle_connect_task has exception: {e}")
+
+
+ def main():
+     device = args.device_id
+     rank = args.rank
+     paddle.set_device(f"gpu:{device}")
+     cache_type = args.cache_dtype
+     speculative_config = SpeculativeConfig(args.speculative_config)
+     num_extra_layers = speculative_config.num_extra_cache_layer
+     num_extra_layer_gpu_blocks = int(args.num_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
+     gpu_cache_kvs = {}
+     gpu_cache_k_tensors = []
+     gpu_cache_v_tensors = []
+
+     for i in range(args.num_hidden_layers + num_extra_layers):
+         num_gpu_blocks = args.num_gpu_blocks if i < args.num_hidden_layers else num_extra_layer_gpu_blocks
+
+         gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+             shape=[
+                 num_gpu_blocks,
+                 args.kv_num_head,
+                 args.block_size,
+                 args.head_dim,
+             ],
+             fill_value=0,
+             dtype=cache_type,
+         )
+         gpu_cache_k_tensors.append(gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
+         gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
+             shape=[
+                 num_gpu_blocks,
+                 args.kv_num_head,
+                 args.block_size,
+                 args.head_dim,
+             ],
+             fill_value=0,
+             dtype=cache_type,
+         )
+         gpu_cache_v_tensors.append(gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
+
+         set_data_ipc(
+             gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
+             f"key_caches_{i}_rank{rank}.device{device}",
+         )
+         set_data_ipc(
+             gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
+             f"value_caches_{i}_rank{rank}.device{device}",
+         )
+     cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in gpu_cache_kvs.items()])
+     logger.info(f"device :{device}")
+     logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
+     logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
+
+     cache_messager = CacheMessager(
+         splitwise_role=args.splitwise_role,
+         transfer_protocol=args.protocol,
+         pod_ip=args.pod_ip,
+         engine_worker_queue_port=args.engine_worker_queue_port,
+         local_data_parallel_id=args.local_data_parallel_id,
+         gpu_cache_kvs=gpu_cache_kvs,
+         rank=rank,
+         nranks=args.mp_num,
+         num_hidden_layers=args.num_hidden_layers + num_extra_layers,
+         gpu_id=device,
+         rdma_port=args.rdma_port,
+     )
+
+     cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+     cache_ready_signal = IPCSignal(
+         name="cache_ready_signal",
+         array=cache_ready_signal_data,
+         dtype=np.int32,
+         suffix=args.engine_pid,
+         create=False,
+     )
+     cache_ready_signal.value[rank] = 1
+     cache_messager.prefill_layerwise_send_cache_thread()
+
+
+ if __name__ == "__main__":
+
+     args = parse_args()
+     logger = get_logger("cache_messager", "cache_messager.log")
+
+     logger.info("create cache messager...")
+     logger.info(f"{args}")
+     main()
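To close, a minimal sketch of how the new __main__ entry point might be launched as a standalone per-rank process from a bootstrap script. The module path, flag values, and two-rank setup are assumptions for illustration, not taken from this commit.

    import subprocess
    import sys

    # Launch one cache messager per tensor-parallel rank (hypothetical values).
    procs = [
        subprocess.Popen(
            [
                sys.executable,
                "-m",
                "fastdeploy.cache_manager.cache_messager",  # module path is an assumption
                "--splitwise_role", "prefill",
                "--rank", str(rank),
                "--device_id", str(rank),
                "--num_hidden_layers", "32",
                "--kv_num_head", "8",
                "--head_dim", "128",
                "--mp_num", "2",
                "--engine_pid", "12345",
                "--engine_worker_queue_port", "9923",
                "--num_gpu_blocks", "1024",
                "--cache_dtype", "bfloat16",
            ]
        )
        for rank in range(2)
    ]
    for proc in procs:
        proc.wait()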