16
16
17
17
import argparse
18
18
import concurrent .futures
19
+ import gc
19
20
import json
20
21
import queue
22
+ import threading
21
23
import time
22
24
import traceback
23
25
24
26
import numpy as np
25
27
import paddle
26
28
29
+ from fastdeploy import envs
27
30
from fastdeploy .cache_manager .cache_data import CacheStatus
28
31
from fastdeploy .config import SpeculativeConfig
29
- from fastdeploy .inter_communicator import EngineCacheQueue , IPCSignal
32
+ from fastdeploy .inter_communicator import EngineCacheQueue , IPCSignal , KVCacheStatus
30
33
from fastdeploy .model_executor .ops .gpu import (
31
34
cuda_host_alloc ,
35
+ cuda_host_free ,
36
+ set_data_ipc ,
32
37
share_external_data ,
33
38
swap_cache_all_layers ,
39
+ unset_data_ipc ,
34
40
)
35
41
from fastdeploy .utils import get_logger
36
42
@@ -93,6 +99,7 @@ def parse_args():
93
99
help = "speculative config" ,
94
100
)
95
101
parser .add_argument ("--local_data_parallel_id" , type = int , default = 0 )
102
+ parser .add_argument ("--create_cache_tensor" , action = "store_true" )
96
103
97
104
args = parser .parse_args ()
98
105
return args
@@ -110,7 +117,6 @@ def __init__(self, args):
110
117
111
118
device = args .device_id
112
119
rank = args .rank
113
- paddle .set_device (f"gpu:{ device } " )
114
120
self .gpu_cache_kvs = {}
115
121
self .cpu_cache_kvs = {}
116
122
self .gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ def __init__(self, args):
126
132
self .n_ranks = args .mp_num
127
133
self .rank = rank
128
134
self .device = device
135
+ self .engine_pid = args .engine_pid
129
136
130
137
address = (args .pod_ip , args .cache_queue_port )
131
138
self .cache_task_queue = EngineCacheQueue (
@@ -136,57 +143,27 @@ def __init__(self, args):
136
143
local_data_parallel_id = args .local_data_parallel_id ,
137
144
)
138
145
139
- self .num_cpu_blocks = args .num_cpu_blocks
140
-
141
- cache_type = args .cache_dtype
142
- cache_shape = [
143
- args .num_gpu_blocks ,
144
- args .kv_num_head ,
145
- args .block_size ,
146
- args .head_dim ,
147
- ]
148
-
149
- for i in range (args .num_layers + self .num_extra_layers ):
150
- num_gpu_blocks = args .num_gpu_blocks if i < args .num_layers else self .num_extra_layer_gpu_blocks
151
- cache_shape [0 ] = num_gpu_blocks
152
- key_name = f"key_caches_{ i } _rank{ rank } .device{ device } "
153
- value_name = f"value_caches_{ i } _rank{ rank } .device{ device } "
154
- key_cache = paddle .empty (shape = [], dtype = cache_type )
155
- value_cache = paddle .empty (shape = [], dtype = cache_type )
156
- key_cache = share_external_data (key_cache , key_name , cache_shape )
157
- value_cache = share_external_data (value_cache , value_name , cache_shape )
158
- self .gpu_cache_kvs [key_name ] = key_cache
159
- self .gpu_cache_kvs [value_name ] = value_cache
160
- self .gpu_cache_k_tensors .append (self .gpu_cache_kvs [key_name ])
161
- self .gpu_cache_v_tensors .append (self .gpu_cache_kvs [value_name ])
162
-
163
- cache_kv_size_byte = sum ([tmp .numel () * 1 for key , tmp in self .gpu_cache_kvs .items ()])
164
- logger .info (f"device :{ self .device } " )
165
- logger .info (f"cache_kv_size_byte : { cache_kv_size_byte } " )
166
- logger .info (f"done init cache (full) gmem alloc : { paddle .device .cuda .memory_allocated ()} " )
167
-
168
- paddle .set_device ("cpu" )
169
- self .k_dst_ptrs = []
170
- self .v_dst_ptrs = []
171
- for i in range (args .num_layers + self .num_extra_layers ):
172
- self .cpu_cache_kvs [f"key_caches_{ i } _rank{ rank } " ] = cuda_host_alloc (
173
- args .num_cpu_blocks * args .bytes_per_layer_per_block
174
- )
175
- self .k_dst_ptrs .append (self .cpu_cache_kvs [f"key_caches_{ i } _rank{ rank } " ])
176
- self .cpu_cache_kvs [f"value_caches_{ i } _rank{ rank } " ] = cuda_host_alloc (
177
- args .num_cpu_blocks * args .bytes_per_layer_per_block
178
- )
179
- self .v_dst_ptrs .append (self .cpu_cache_kvs [f"value_caches_{ i } _rank{ rank } " ])
180
-
181
146
cache_ready_signal_data = np .zeros (shape = [args .mp_num ], dtype = np .int32 )
182
147
self .cache_ready_signal = IPCSignal (
183
148
name = "cache_ready_signal" ,
184
149
array = cache_ready_signal_data ,
185
150
dtype = np .int32 ,
186
- suffix = args .engine_pid ,
151
+ suffix = self .engine_pid ,
152
+ create = False ,
153
+ )
154
+ swap_space_ready_data = np .zeros (shape = [args .mp_num ], dtype = np .int32 )
155
+ self .swap_space_ready_signal = IPCSignal (
156
+ name = "swap_space_ready_signal" ,
157
+ array = swap_space_ready_data ,
158
+ dtype = np .int32 ,
159
+ suffix = self .engine_pid ,
187
160
create = False ,
188
161
)
189
- self .cache_ready_signal .value [self .rank ] = 1
162
+
163
+ self .num_cpu_blocks = args .num_cpu_blocks
164
+
165
+ self ._init_cpu_cache (args )
166
+ self ._init_gpu_cache (args )
190
167
191
168
cache_task_broadcast_data = np .zeros (shape = [1 ], dtype = np .int32 )
192
169
self .cache_task_broadcast_signal = IPCSignal (
@@ -197,6 +174,76 @@ def __init__(self, args):
197
174
create = False ,
198
175
)
199
176
177
+ threading .Thread (target = self .clear_or_update_caches , args = [args ], daemon = True ).start ()
178
+
179
def _init_gpu_cache(self, args):
    """Create or attach the per-layer GPU key/value cache tensors for this rank.

    Two modes, selected by ``args.create_cache_tensor``:

    * set: allocate zero-filled key/value tensors for every layer, register
      each one under its name with ``set_data_ipc``, and finally mark this
      rank's slot in ``cache_ready_signal`` as 1.
    * unset: poll this rank's slot in ``cache_ready_signal`` until it becomes 1
      (set by whoever creates the tensors; the log message calls them
      "runners"), then attach to the named tensors with ``share_external_data``.

    In both modes the tensors are stored in ``self.gpu_cache_kvs`` and appended
    to ``self.gpu_cache_k_tensors`` / ``self.gpu_cache_v_tensors``.

    Args:
        args: Parsed command-line arguments. Reads ``create_cache_tensor``,
            ``num_layers``, ``num_gpu_blocks``, ``kv_num_head``, ``block_size``,
            ``head_dim`` and ``cache_dtype``.
    """
    rank_tag = f"[rank {self.rank}/{self.n_ranks}]"

    if not args.create_cache_tensor:
        logger.info(f"{rank_tag} Waiting for runners to create kv cache.")
        while self.cache_ready_signal.value[self.rank] != 1:
            time.sleep(0.1)
        logger.info(f"{rank_tag} OK! Stop waiting.")

    logger.info(f"{rank_tag} Initializing kv cache for all layers.")
    paddle.set_device(f"gpu:{self.device}")
    for i in range(args.num_layers + self.num_extra_layers):
        # Layers past args.num_layers are the "extra" layers and use their own block count.
        num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
        cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
        key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
        val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"

        if args.create_cache_tensor:
            logger.info(f"{rank_tag} ..creating kv cache for layer {i}: {cache_shape}")
            key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
            val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
            set_data_ipc(key_cache, key_name)
            set_data_ipc(val_cache, val_name)
        else:
            logger.info(f"{rank_tag} ..attaching kv cache for layer {i}: {cache_shape}")
            # Placeholder tensors; share_external_data returns the tensor bound to the named data.
            key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
            val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
            key_cache = share_external_data(key_cache, key_name, cache_shape)
            val_cache = share_external_data(val_cache, val_name, cache_shape)

        self.gpu_cache_kvs[key_name] = key_cache
        self.gpu_cache_kvs[val_name] = val_cache
        self.gpu_cache_k_tensors.append(key_cache)
        self.gpu_cache_v_tensors.append(val_cache)

    if args.create_cache_tensor:
        # Fixed: this message was missing the f-prefix and logged the raw placeholders.
        logger.info(f"{rank_tag} ✅ kv cache is ready!")
        self.cache_ready_signal.value[self.rank] = 1

    # NOTE(review): despite the name, this sums element counts (numel), not bytes.
    cache_kv_size_byte = sum(tensor.numel() for tensor in self.gpu_cache_kvs.values())
    logger.info(f"{rank_tag} device :{self.device}")
    logger.info(f"{rank_tag} cache_kv_size_byte : {cache_kv_size_byte}")
    logger.info(f"{rank_tag} done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
223
+
224
def _init_cpu_cache(self, args):
    """Allocate the per-layer swap space (CPU cache) for this rank.

    For every layer, two host buffers of
    ``args.num_cpu_blocks * args.bytes_per_layer_per_block`` bytes are obtained
    from ``cuda_host_alloc`` (one for keys, one for values), stored in
    ``self.cpu_cache_kvs`` and appended to ``self.k_dst_ptrs`` /
    ``self.v_dst_ptrs``. When done (or when no swap space is requested), this
    rank's slot in ``swap_space_ready_signal`` is set to 1.

    Args:
        args: Parsed command-line arguments. Reads ``num_cpu_blocks``,
            ``num_layers`` and ``bytes_per_layer_per_block``.
    """
    rank_tag = f"[rank {self.rank}/{self.n_ranks}]"

    # Always define the pointer lists, even when no swap space is configured,
    # so code that later iterates or clears them does not hit AttributeError.
    self.k_dst_ptrs = []
    self.v_dst_ptrs = []

    if args.num_cpu_blocks == 0:
        logger.info(f"{rank_tag} 💡 no swap space (cpu cache) is specified.")
        self.swap_space_ready_signal.value[self.rank] = 1
        return

    logger.info(f"{rank_tag} Initializing swap space (cpu cache) for all layers.")
    paddle.set_device("cpu")
    # Same size for every layer and for both key and value buffers.
    need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
    for i in range(args.num_layers + self.num_extra_layers):
        key_name = f"key_caches_{i}_rank{self.rank}"
        val_name = f"value_caches_{i}_rank{self.rank}"
        # Logged size is doubled because a key buffer and a value buffer are both allocated.
        logger.info(
            f"{rank_tag} ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f} GB"
        )
        self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
        self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
        self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
        self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])

    logger.info(f"{rank_tag} ✅ swap space (cpu cache) is ready!")
    self.swap_space_ready_signal.value[self.rank] = 1
246
+
200
247
def _do_swap_to_cpu_task (
201
248
self ,
202
249
swap_node_ids ,
@@ -394,6 +441,92 @@ def _transfer_data(
394
441
transfer_task_id ,
395
442
)
396
443
444
def clear_or_update_caches(self, args):
    """Poll the ``kv_cache_status`` signal forever and clear/restore caches on request.

    Intended to run in a background thread. Every 0.1 s the shared
    ``kv_cache_status`` value is inspected:

    * ``KVCacheStatus.CLEARING``: optionally free the swap space (when
      ``envs.FD_ENABLE_SWAP_SPACE_CLEARING`` is truthy), release every GPU
      cache tensor via ``unset_data_ipc``, reset this rank's ready flags, wait
      until all ranks have done the same, then set the status to
      ``KVCacheStatus.CLEARED``.
    * ``KVCacheStatus.UPDATING``: re-run ``_init_cpu_cache`` (same env switch)
      and ``_init_gpu_cache``, wait until all ranks are ready, then set the
      status to ``KVCacheStatus.NORMAL``.

    Failures are logged with a traceback and do not end the loop; because the
    status value is left unchanged, the same operation is attempted again on
    the next poll.

    Args:
        args: Parsed command-line arguments; passed to the cache initialisers
            and read for ``mp_num`` (number of ranks to wait for).
    """
    rank_tag = f"[rank {self.rank}/{self.n_ranks}]"
    logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
    logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
    kv_cache_status = np.zeros([1], dtype=np.int32)
    kv_cache_status_signal = IPCSignal(
        name="kv_cache_status",
        array=kv_cache_status,
        dtype=np.int32,
        suffix=self.engine_pid,
        create=False,
    )
    while True:
        if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
            try:
                logger.info(f"{rank_tag} Start clearing caches {self.cache_ready_signal.value}")
                # clear cpu caches
                if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                    paddle.set_device("cpu")
                    # The pointer lists may be absent if swap space was never allocated.
                    k_ptrs = getattr(self, "k_dst_ptrs", [])
                    v_ptrs = getattr(self, "v_dst_ptrs", [])
                    for ptr in k_ptrs + v_ptrs:
                        cuda_host_free(ptr)
                    self.cpu_cache_kvs.clear()
                    k_ptrs.clear()
                    v_ptrs.clear()
                    gc.collect()
                    # reset swap_space_ready_signal and wait for every rank to do so
                    self.swap_space_ready_signal.value[self.rank] = 0
                    while np.sum(self.swap_space_ready_signal.value) != 0:
                        time.sleep(0.1)

                # clear gpu caches
                paddle.set_device(f"gpu:{self.device}")
                for name, tensor in self.gpu_cache_kvs.items():
                    unset_data_ipc(tensor, name, True, False)
                self.gpu_cache_kvs.clear()
                self.gpu_cache_k_tensors.clear()
                self.gpu_cache_v_tensors.clear()

                # reset cache_ready_signal
                self.cache_ready_signal.value[self.rank] = 0
                logger.info(f"{rank_tag} Finish clearing caches {self.cache_ready_signal.value}")

                # Wait for all ranks' caches to be cleared. This must loop: a single
                # sleep would report CLEARED while other ranks still hold their caches.
                while np.sum(self.cache_ready_signal.value) != 0:
                    time.sleep(0.1)

                # reset kv_cache_status_signal
                kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
                logger.info("All ranks finish clearing caches")

            except Exception as e:
                logger.error(f"{rank_tag} Failed to clear caches: {e}, {traceback.format_exc()}")

        elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
            try:
                logger.info(f"{rank_tag} Start restoring caches {self.cache_ready_signal.value}")
                # restore cpu cache
                if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                    self._init_cpu_cache(args)
                    while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
                        time.sleep(0.1)

                # restore gpu cache and set cache_ready_signal
                self._init_gpu_cache(args)
                logger.info(f"{rank_tag} Finish restoring caches {self.cache_ready_signal.value}")

                # wait for all ranks caches to be ready
                while np.sum(self.cache_ready_signal.value) != args.mp_num:
                    time.sleep(0.1)

                # set kv_cache_status_signal
                logger.info("All ranks finish restoring caches")
                kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL

            except Exception as e:
                logger.error(f"{rank_tag} Failed to restore caches: {e}, {traceback.format_exc()}")

        time.sleep(0.1)
529
+
397
530
398
531
def main ():
399
532
"""
0 commit comments