import triton
import triton.language as tl
- from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size


@triton.jit
@@ -22,9 +21,16 @@ def _offload_gpu_kv_to_cpu(
    page_indexes_ptr,
    page_readies_ptr,
    layer_num,
-    head_all_dim,
-    cpu_head_offset,
-    BLOCK_HEAD_ALL_DIM: tl.constexpr,
+    head_dim,
+    cpu_k_start_head_index: tl.constexpr,
+    cpu_k_head_num: tl.constexpr,
+    gpu_k_start_head_index: tl.constexpr,
+    gpu_k_head_num: tl.constexpr,
+    cpu_v_start_head_index: tl.constexpr,
+    cpu_v_head_num: tl.constexpr,
+    gpu_v_start_head_index: tl.constexpr,
+    gpu_v_head_num: tl.constexpr,
+    BLOCK_HEAD_DIM: tl.constexpr,
    TOKEN_BLOCK: tl.constexpr,
):
    block_index = tl.program_id(0)
@@ -38,28 +44,61 @@ def _offload_gpu_kv_to_cpu(

    token_range = block_index * TOKEN_BLOCK + tl.arange(0, TOKEN_BLOCK)
    token_indexes = tl.load(token_indexes_ptr + token_range).to(tl.int64)
-    head_all_dim_range = tl.arange(0, BLOCK_HEAD_ALL_DIM)
+    head_dim_range = tl.arange(0, BLOCK_HEAD_DIM)
+    head_dim_mask = head_dim_range < head_dim
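+    # lanes past head_dim are padding: BLOCK_HEAD_DIM is head_dim rounded up to a power of two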

    for layer_index in range(layer_num):
-        gpu_ptr = (
-            gpu_kv_cache_ptr
-            + layer_index.to(tl.int64) * gpu_stride0
-            + token_indexes[:, None] * gpu_stride1
-            + head_all_dim_range[None, :]
-        )
-        gpu_data = tl.load(gpu_ptr, mask=(head_all_dim_range[None, :] < head_all_dim), other=0.0)
-        cpu_ptr = (
-            cpu_kv_cache_ptr
-            + cpu_page_index * cpu_stride0
-            + layer_index * cpu_stride1
-            + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
-            + (cpu_head_offset + head_all_dim_range[None, :])
-        )
-        tl.store(
-            cpu_ptr,
-            gpu_data,
-            mask=(head_all_dim_range[None, :] < head_all_dim),
-        )
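+        # K heads: copy this rank's GPU K-head range into its CPU K-head slots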
+        for k_head_index in range(gpu_k_head_num):
+            gpu_k_head_index = k_head_index + gpu_k_start_head_index
+            cpu_k_head_index = k_head_index + cpu_k_start_head_index
+
+            gpu_ptr = (
+                gpu_kv_cache_ptr
+                + layer_index.to(tl.int64) * gpu_stride0
+                + token_indexes[:, None] * gpu_stride1
+                + gpu_k_head_index.to(tl.int64) * gpu_stride2
+                + head_dim_range[None, :]
+            )
+            gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0)
+            cpu_ptr = (
+                cpu_kv_cache_ptr
+                + cpu_page_index * cpu_stride0
+                + layer_index * cpu_stride1
+                + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+                + cpu_k_head_index * cpu_stride3
+                + head_dim_range[None, :]
+            )
+            tl.store(
+                cpu_ptr,
+                gpu_data,
+                mask=head_dim_mask[None, :],
+            )
+
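+        # V heads: the same copy over this rank's V-head range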
+        for v_head_index in range(gpu_v_head_num):
+            gpu_v_head_index = v_head_index + gpu_v_start_head_index
+            cpu_v_head_index = v_head_index + cpu_v_start_head_index
+
+            gpu_ptr = (
+                gpu_kv_cache_ptr
+                + layer_index.to(tl.int64) * gpu_stride0
+                + token_indexes[:, None] * gpu_stride1
+                + gpu_v_head_index.to(tl.int64) * gpu_stride2
+                + head_dim_range[None, :]
+            )
+            gpu_data = tl.load(gpu_ptr, mask=head_dim_mask[None, :], other=0.0)
+            cpu_ptr = (
+                cpu_kv_cache_ptr
+                + cpu_page_index * cpu_stride0
+                + layer_index * cpu_stride1
+                + tl.arange(0, TOKEN_BLOCK)[:, None] * cpu_stride2
+                + cpu_v_head_index * cpu_stride3
+                + head_dim_range[None, :]
+            )
+            tl.store(
+                cpu_ptr,
+                gpu_data,
+                mask=head_dim_mask[None, :],
+            )
    return


@@ -70,6 +109,9 @@ def offload_gpu_kv_to_cpu(
    cpu_kv_cache: torch.Tensor,
    page_indexes: torch.Tensor,
    page_readies: torch.Tensor,
+    tp_index: int,
+    tp_world_size: int,
+    _cache_data={},  # deliberately mutable default: per-process memo of the head mapping
):
74116 """
75117 this function is used to offload GPU KV cache to CPU KV cache.
@@ -81,25 +123,108 @@ def offload_gpu_kv_to_cpu(
    page_readies: (page_num,)
    """
    token_block_size = cpu_kv_cache.shape[2]
-    token_num = page_indexes.shape[0] * token_block_size
-    assert token_indexes.shape[0] >= token_num
+    token_num = token_indexes.shape[0]
+    assert token_num == page_indexes.shape[0] * token_block_size
    assert page_indexes.shape == page_readies.shape
-    page_num = page_indexes.shape[0]
-    head_all_dim = gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2]
-    BLOCK_HEAD_ALL_DIM = triton.next_power_of_2(gpu_kv_cache.shape[-1] * gpu_kv_cache.shape[-2])

-    # Calculate head offset for tensor parallelism
-    tp_rank = get_current_rank_in_dp()
-    tp_num = get_dp_world_size()
    gpu_heads = gpu_kv_cache.shape[2]
    gpu_head_dim = gpu_kv_cache.shape[3]
    cpu_heads = cpu_kv_cache.shape[3]
-    factor = (tp_num * gpu_heads) // cpu_heads
-    cpu_head_offset = (tp_rank // factor) * gpu_heads * gpu_head_dim
-    if tp_rank % factor != 0:
-        # redundant kv does not need to offload
+    cpu_head_dim = cpu_kv_cache.shape[4]
+    assert gpu_head_dim == cpu_head_dim
+    head_dim = gpu_head_dim
+    scale_size = (tp_world_size * gpu_heads) // cpu_heads  # TP ranks whose GPU heads map onto the same CPU head slots
+
+    # Compute the mapping between the GPU head indexes to copy and their CPU head slots
+    if (gpu_heads, cpu_heads, tp_index, tp_world_size) in _cache_data:
+        need_offload, head_info_tuple = _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)]
+    else:
+        if cpu_heads > 1:
+            assert (tp_world_size * gpu_heads) % cpu_heads == 0
+            assert cpu_heads % 2 == 0
+
+            cpu_heads_index = (
+                torch.arange(0, cpu_heads, device="cpu", dtype=torch.int32)
+                .view(cpu_heads, 1)
+                .tile((1, scale_size))
+                .view(2, tp_world_size, -1)
+            )
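+            # e.g. gpu_heads=2, cpu_heads=8, tp_world_size=8 -> scale_size=2:
+            # row 0 (K) per rank is [0, 0, 1, 1, 2, 2, 3, 3] and row 1 (V) is [4, 4, 5, 5, 6, 6, 7, 7];
+            # ranks sharing a slot hold duplicate KV heads, so only one of them needs to offload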
+            # k
+            k_cpu_heads_index = cpu_heads_index[0][tp_index]
+            # v
+            v_cpu_heads_index = cpu_heads_index[1][tp_index]
+
+            cpu_heads_index = torch.cat([k_cpu_heads_index, v_cpu_heads_index], dim=0).view(2, -1).numpy()
+            gpu_heads_index = torch.arange(0, gpu_heads, device="cpu", dtype=torch.int32).view(2, -1).numpy()
+
+            need_offload = tp_index % scale_size == 0
+
+            cpu_k_start_head_index = cpu_heads_index[0, 0]
+            cpu_k_head_num = len(cpu_heads_index[0])
+            gpu_k_start_head_index = gpu_heads_index[0, 0]
+            gpu_k_head_num = len(gpu_heads_index[0])
+            assert cpu_k_head_num == gpu_k_head_num
+            cpu_v_start_head_index = cpu_heads_index[1, 0]
+            cpu_v_head_num = len(cpu_heads_index[1])
+            gpu_v_start_head_index = gpu_heads_index[1, 0]
+            gpu_v_head_num = len(gpu_heads_index[1])
+            assert cpu_v_head_num == gpu_v_head_num
+
+            head_info_tuple = (
+                cpu_k_start_head_index,
+                cpu_k_head_num,
+                gpu_k_start_head_index,
+                gpu_k_head_num,
+                cpu_v_start_head_index,
+                cpu_v_head_num,
+                gpu_v_start_head_index,
+                gpu_v_head_num,
+            )
+
+        else:
+            assert gpu_heads == 1
+            assert cpu_heads == 1
+
+            need_offload = tp_index == 0
+
+            # single combined KV head (e.g. MLA-style caches): everything goes through the K path, no V copy
+            cpu_k_start_head_index = 0
+            cpu_k_head_num = 1
+            gpu_k_start_head_index = 0
+            gpu_k_head_num = 1
+            cpu_v_start_head_index = 0
+            cpu_v_head_num = 0
+            gpu_v_start_head_index = 0
+            gpu_v_head_num = 0
+            head_info_tuple = (
+                cpu_k_start_head_index,
+                cpu_k_head_num,
+                gpu_k_start_head_index,
+                gpu_k_head_num,
+                cpu_v_start_head_index,
+                cpu_v_head_num,
+                gpu_v_start_head_index,
+                gpu_v_head_num,
+            )
+
+        _cache_data[(gpu_heads, cpu_heads, tp_index, tp_world_size)] = (need_offload, head_info_tuple)
+
+    (
+        cpu_k_start_head_index,
+        cpu_k_head_num,
+        gpu_k_start_head_index,
+        gpu_k_head_num,
+        cpu_v_start_head_index,
+        cpu_v_head_num,
+        gpu_v_start_head_index,
+        gpu_v_head_num,
+    ) = head_info_tuple
+
+    if not need_offload:
        return

+    assert token_block_size == triton.next_power_of_2(token_block_size)
+    page_num = page_indexes.shape[0]
+
    grid = (page_num,)
    num_warps = 4

@@ -119,9 +244,16 @@ def offload_gpu_kv_to_cpu(
        page_indexes_ptr=page_indexes,
        page_readies_ptr=page_readies,
        layer_num=gpu_kv_cache.shape[0],
-        head_all_dim=head_all_dim,
-        cpu_head_offset=cpu_head_offset,
-        BLOCK_HEAD_ALL_DIM=BLOCK_HEAD_ALL_DIM,
+        head_dim=head_dim,
+        cpu_k_start_head_index=cpu_k_start_head_index,
+        cpu_k_head_num=cpu_k_head_num,
+        gpu_k_start_head_index=gpu_k_start_head_index,
+        gpu_k_head_num=gpu_k_head_num,
+        cpu_v_start_head_index=cpu_v_start_head_index,
+        cpu_v_head_num=cpu_v_head_num,
+        gpu_v_start_head_index=gpu_v_start_head_index,
+        gpu_v_head_num=gpu_v_head_num,
+        BLOCK_HEAD_DIM=triton.next_power_of_2(head_dim),
        TOKEN_BLOCK=token_block_size,
        num_warps=num_warps,
        num_stages=1,
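
A minimal call sketch for the new entry point (not part of the PR; the shapes, dtypes, and the names
of the leading `token_indexes`/`gpu_kv_cache` parameters elided from the hunks are assumptions
inferred from the diff):

    import torch

    # assumed sizes: 2 GPU KV heads per rank, 8 CPU heads, 8 TP ranks -> scale_size == 2,
    # so only even tp_index values actually launch the kernel
    layer_num, max_token_num, gpu_heads, head_dim = 32, 4096, 2, 128
    total_pages, token_block_size, cpu_heads = 64, 64, 8

    # GPU cache: [layer, token, head, head_dim]; CPU cache: [page, layer, token_in_page, head, head_dim]
    gpu_kv_cache = torch.randn(layer_num, max_token_num, gpu_heads, head_dim, device="cuda", dtype=torch.float16)
    # pinned so the Triton kernel can write host memory directly
    cpu_kv_cache = torch.zeros(total_pages, layer_num, token_block_size, cpu_heads, head_dim, dtype=torch.float16).pin_memory()

    page_indexes = torch.tensor([3, 7], device="cuda", dtype=torch.int32)
    page_readies = torch.zeros(2, device="cuda", dtype=torch.bool)
    # exactly one GPU token index per slot of every page being offloaded
    token_indexes = torch.arange(2 * token_block_size, device="cuda", dtype=torch.int32)

    # offload_gpu_kv_to_cpu is the wrapper defined in this diff
    offload_gpu_kv_to_cpu(
        token_indexes=token_indexes,
        gpu_kv_cache=gpu_kv_cache,
        cpu_kv_cache=cpu_kv_cache,
        page_indexes=page_indexes,
        page_readies=page_readies,
        tp_index=0,
        tp_world_size=8,
    )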