22
33import triton
44import triton .language as tl
5+ from lightllm .utils .dist_utils import get_current_rank_in_dp , get_dp_world_size
56
67
78@triton .jit
@@ -72,16 +73,13 @@ def offload_gpu_kv_to_cpu(
7273):
7374 """
7475 this function is used to offload GPU KV cache to CPU KV cache.
75- Supports tensor parallelism (TP > 1).
7676 Args:
7777 token_indexes: (token_num,)
7878 gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
7979 cpu_kv_cache: (all_page_num, layer_num, token_block_size, head_num, head_dim)
8080 page_indexes: (page_num,)
8181 page_readies: (page_num,)
8282 """
83- from lightllm .utils .dist_utils import get_current_rank_in_dp , get_dp_world_size
84-
8583 token_block_size = cpu_kv_cache .shape [2 ]
8684 token_num = page_indexes .shape [0 ] * token_block_size
8785 assert token_indexes .shape [0 ] >= token_num
@@ -92,9 +90,15 @@ def offload_gpu_kv_to_cpu(
9290
9391 # Calculate head offset for tensor parallelism
9492 tp_rank = get_current_rank_in_dp ()
93+ tp_num = get_dp_world_size ()
9594 gpu_heads = gpu_kv_cache .shape [2 ]
9695 gpu_head_dim = gpu_kv_cache .shape [3 ]
97- cpu_head_offset = tp_rank * gpu_heads * gpu_head_dim
96+ cpu_heads = cpu_kv_cache .shape [3 ]
97+ factor = (tp_num * gpu_heads ) // cpu_heads
98+ cpu_head_offset = (tp_rank // factor ) * gpu_heads * gpu_head_dim
99+ if tp_rank % factor != 0 :
100+ # redundant kv does not need to be offloaded
101+ return
98102
99103 grid = (page_num ,)
100104 num_warps = 4
@@ -142,7 +146,6 @@ def _load_cpu_cache_to_gpu(
142146 page_indexes_ptr ,
143147 layer_num ,
144148 head_all_dim ,
145- all_move_token_num ,
146149 cpu_head_offset ,
147150 BLOCK_HEAD_ALL_DIM : tl .constexpr ,
148151 TOKEN_BLOCK : tl .constexpr ,
@@ -152,17 +155,11 @@ def _load_cpu_cache_to_gpu(
152155 if cpu_page_index == - 1 :
153156 return
154157
155- gpu_stride0 = tl .cast (gpu_stride0 , dtype = tl .int64 )
156- padded_size = TOKEN_BLOCK * tl .num_programs (0 ) - all_move_token_num
157- head_all_dim_range = tl .arange (0 , BLOCK_HEAD_ALL_DIM )
158158 token_range = block_index * TOKEN_BLOCK + tl .arange (0 , TOKEN_BLOCK )
159- token_range = token_range - padded_size
160-
161- token_mask = token_range >= 0
159+ token_indexes = tl .load (token_indexes_ptr + token_range ).to (tl .int64 )
160+ head_all_dim_range = tl .arange (0 , BLOCK_HEAD_ALL_DIM )
162161 head_dim_mask = head_all_dim_range < head_all_dim
163162
164- token_indexes = tl .load (token_indexes_ptr + token_range , mask = token_mask , other = 0 ).to (tl .int64 )
165-
166163 cpu_page_index = tl .load (page_indexes_ptr + block_index ).to (tl .int64 )
167164 for layer_index in range (layer_num ):
168165 cpu_ptr = (
@@ -176,14 +173,14 @@ def _load_cpu_cache_to_gpu(
176173
177174 gpu_ptr = (
178175 gpu_kv_cache_ptr
179- + layer_index * gpu_stride0
176+ + layer_index . to ( tl . int64 ) * gpu_stride0
180177 + token_indexes [:, None ] * gpu_stride1
181178 + head_all_dim_range [None , :]
182179 )
183180 tl .store (
184181 gpu_ptr ,
185182 cpu_data ,
186- mask = token_mask [:, None ] & head_dim_mask [None , :],
183+ mask = head_dim_mask [None , :],
187184 )
188185 return
189186
@@ -196,27 +193,28 @@ def load_cpu_kv_to_gpu(
196193 page_indexes : torch .Tensor ,
197194):
198195 """
199- this function is used to load CPU KV cache to GPU KV cache.
200- Supports tensor parallelism (TP > 1).
196+ this function is used to load CPU KV cache to GPU KV cache.
201197 Args:
202198 mem_indexes: (token_num,)
203199 gpu_kv_cache: (layer_num, token_num, head_num, head_dim)
204200 cpu_kv_cache: (page_num, layer_num, token_block_size, head_num, head_dim)
205201 page_indexes: (page_num,)
206202 """
207- from lightllm .utils .dist_utils import get_current_rank_in_dp , get_dp_world_size
208-
209203 token_block_size = cpu_kv_cache .shape [2 ]
210204 token_num = page_indexes .shape [0 ] * token_block_size
211205 assert mem_indexes .shape [0 ] >= token_num
212206 page_num = page_indexes .shape [0 ]
207+ assert len (mem_indexes ) == page_num * token_block_size
213208 BLOCK_HEAD_ALL_DIM = triton .next_power_of_2 (gpu_kv_cache .shape [- 1 ] * gpu_kv_cache .shape [- 2 ])
214209
215210 # Calculate head offset for tensor parallelism
216211 tp_rank = get_current_rank_in_dp ()
212+ tp_num = get_dp_world_size ()
217213 gpu_heads = gpu_kv_cache .shape [2 ]
218214 gpu_head_dim = gpu_kv_cache .shape [3 ]
219- cpu_head_offset = tp_rank * gpu_heads * gpu_head_dim
215+ cpu_heads = cpu_kv_cache .shape [3 ]
216+ factor = (tp_num * gpu_heads ) // cpu_heads
217+ cpu_head_offset = (tp_rank // factor ) * gpu_heads * gpu_head_dim
220218
221219 grid = (page_num ,)
222220 num_warps = 1
@@ -237,7 +235,6 @@ def load_cpu_kv_to_gpu(
237235 page_indexes_ptr = page_indexes ,
238236 layer_num = gpu_kv_cache .shape [0 ],
239237 head_all_dim = gpu_kv_cache .shape [- 1 ] * gpu_kv_cache .shape [- 2 ],
240- all_move_token_num = len (mem_indexes ),
241238 cpu_head_offset = cpu_head_offset ,
242239 BLOCK_HEAD_ALL_DIM = BLOCK_HEAD_ALL_DIM ,
243240 TOKEN_BLOCK = token_block_size ,
0 commit comments