@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 
 
 @triton.jit
@@ -22,6 +23,7 @@ def _offload_gpu_kv_to_cpu(
     page_readies_ptr,
     layer_num,
     head_all_dim,
+    cpu_head_offset,
     BLOCK_HEAD_ALL_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
@@ -38,12 +40,10 @@ def _offload_gpu_kv_to_cpu(
     token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
     head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
 
-    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
-
     for layer_index in range(layer_num):
         gpu_ptr = (
             gpu_kv_cache_ptr
-            + layer_index * gpu_stride0
+            + layer_index.to(tl.int64) * gpu_stride0
             + token_indexes[:, None] * gpu_stride1
             + head_all_dim_range[None, :]
         )
@@ -53,7 +53,7 @@ def _offload_gpu_kv_to_cpu(
             + cpu_page_index * cpu_stride0
             + layer_index * cpu_stride1
             + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + head_all_dim_range[None, :]
+            + (cpu_head_offset + head_all_dim_range[None, :])
         )
         tl.store(
             cpu_ptr,
@@ -88,6 +88,18 @@ def offload_gpu_kv_to_cpu(
     head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
     BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
+    # Calculate head offset for tensor parallelism
+    tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
+    gpu_heads = gpu_kv_cache.shape[2]
+    gpu_head_dim = gpu_kv_cache.shape[3]
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
+    if tp_rank % factor != 0:
+        # redundant kv copies do not need to be offloaded
+        return
+
     grid = (page_num,)
     num_warps = 4
 
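For intuition, a worked example of the head-offset arithmetic added above, with hypothetical sizes: when tp_num * gpu_heads exceeds cpu_heads (GQA with KV heads replicated across TP ranks), factor is the replication count, only the first rank of each replica group offloads, and that rank writes its heads at cpu_head_offset inside the flattened CPU head axis:

    # Hypothetical sizes: 8 TP ranks, 1 KV head per rank, head_dim 128,
    # CPU cache storing all 4 KV heads per token.
    tp_num, gpu_heads, gpu_head_dim, cpu_heads = 8, 1, 128, 4
    factor = (tp_num * gpu_heads) // cpu_heads  # 2: each KV head lives on 2 ranks
    for tp_rank in range(tp_num):
        cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
        offloads = tp_rank % factor == 0  # one writer per replica group
        print(tp_rank, cpu_head_offset, offloads)
    # Ranks 0/2/4/6 write at offsets 0/128/256/384, exactly tiling the
    # cpu_heads * head_dim = 512 slots; ranks 1/3/5/7 return early.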
@@ -108,6 +120,7 @@ def offload_gpu_kv_to_cpu(
         page_readies_ptr=page_readies,
         layer_num=gpu_kv_cache.shape[0],
         head_all_dim=head_all_dim,
+        cpu_head_offset=cpu_head_offset,
         BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
         TOKEN_BLOCK=token_block_size,
         num_warps=num_warps,
@@ -133,7 +146,7 @@ def _load_cpu_cache_to_gpu(
     page_indexes_ptr,
     layer_num,
     head_all_dim,
-    all_move_token_num,
+    cpu_head_offset,
     BLOCK_HEAD_ALL_DIM: tl.constexpr,
     TOKEN_BLOCK: tl.constexpr,
 ):
@@ -142,38 +155,32 @@ def _load_cpu_cache_to_gpu(
     if cpu_page_index == -1:
         return
 
-    gpu_stride0 = tl.cast(gpu_stride0, dtype=tl.int64)
-    padded_size = TOKEN_BLOCK * tl.num_programs(0) - all_move_token_num
-    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
-    token_range = token_range - padded_size
-
-    token_mask = token_range >= 0
+    token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
+    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
     head_dim_mask = head_all_dim_range < head_all_dim
 
-    token_indexes = tl.load(token_indexes_ptr + token_range, mask=token_mask, other=0).to(tl.int64)
-
-    cpu_page_index = tl.load(page_indexes_ptr + block_index)
+    cpu_page_index = tl.load(page_indexes_ptr + block_index).to(tl.int64)
     for layer_index in range(layer_num):
         cpu_ptr = (
             cpu_kv_cache_ptr
             + cpu_page_index * cpu_stride0
             + layer_index * cpu_stride1
             + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + head_all_dim_range[None, :]
+            + (cpu_head_offset + head_all_dim_range[None, :])
         )
         cpu_data = tl.load(cpu_ptr, mask=head_dim_mask[None, :], other=0.0)
 
         gpu_ptr = (
             gpu_kv_cache_ptr
-            + layer_index * gpu_stride0
+            + layer_index.to(tl.int64) * gpu_stride0
             + token_indexes[:, None] * gpu_stride1
             + head_all_dim_range[None, :]
         )
         tl.store(
             gpu_ptr,
             cpu_data,
-            mask=token_mask[:, None] & head_dim_mask[None, :],
+            mask=head_dim_mask[None, :],
         )
     return
 
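Note on the mask removal above: the old kernel front-padded token_range (via padded_size and token_mask) to tolerate a total token count that was not a multiple of TOKEN_BLOCK. The rewrite assumes full pages instead, which the wrapper enforces with the new assert in the next hunk (len(mem_indexes) == page_num * token_block_size); under that invariant padded_size would always be zero and every token slot is in bounds. A minimal sketch of the invariant, with hypothetical sizes:

    page_num, token_block_size = 4, 64
    mem_indexes = list(range(page_num * token_block_size))  # one entry per slot
    for block_index in range(page_num):
        token_range = range(block_index * token_block_size,
                            (block_index + 1) * token_block_size)
        # every TOKEN_BLOCK slot maps to a valid mem_indexes entry
        assert all(t < len(mem_indexes) for t in token_range)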
@@ -197,12 +204,22 @@ def load_cpu_kv_to_gpu(
     token_num = page_indexes.shape[0] * token_block_size
     assert mem_indexes.shape[0] >= token_num
     page_num = page_indexes.shape[0]
+    assert len(mem_indexes) == page_num * token_block_size
     BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])
 
+    # Calculate head offset for tensor parallelism
+    tp_rank = get_current_rank_in_dp()
+    tp_num = get_dp_world_size()
+    gpu_heads = gpu_kv_cache.shape[2]
+    gpu_head_dim = gpu_kv_cache.shape[3]
+    cpu_heads = cpu_kv_cache.shape[3]
+    factor = (tp_num * gpu_heads) // cpu_heads
+    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
+
     grid = (page_num,)
     num_warps = 1
 
-    _offload_gpu_kv_to_cpu[grid](
+    _load_cpu_cache_to_gpu[grid](
         token_indexes_ptr=mem_indexes,
         gpu_kv_cache_ptr=gpu_kv_cache,
         gpu_stride0=gpu_kv_cache.stride(0),
@@ -218,7 +235,7 @@ def load_cpu_kv_to_gpu(
         page_indexes_ptr=page_indexes,
         layer_num=gpu_kv_cache.shape[0],
         head_all_dim=gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2],
-        all_move_token_num=len(mem_indexes),
+        cpu_head_offset=cpu_head_offset,
         BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
         TOKEN_BLOCK=token_block_size,
         num_warps=num_warps,
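For context, a minimal sketch of the tensor layouts these wrappers appear to assume, inferred from the shape and stride indexing above (gpu_kv_cache.shape[2]/[3] as per-rank heads and head_dim, cpu_kv_cache.shape[3] as the full head count, CPU strides 0/1/2 as page/layer/token). The sizes are hypothetical and this is not the repository's allocation code:

    import torch

    layer_num, total_tokens, gpu_heads, head_dim = 32, 4096, 1, 128
    page_num, token_block_size, cpu_heads = 8, 64, 4

    # Per-TP-rank GPU shard: [layer, token, kv_head, head_dim]
    gpu_kv_cache = torch.empty(layer_num, total_tokens, gpu_heads, head_dim,
                               dtype=torch.float16, device="cuda")
    # Paged CPU cache holding the full, unsharded head axis:
    # [page, layer, token_in_page, kv_head, head_dim]
    cpu_kv_cache = torch.empty(page_num, layer_num, token_block_size, cpu_heads,
                               head_dim, dtype=torch.float16, pin_memory=True)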