@@ -68,3 +68,241 @@ def triton_index_gather(input, indices):
                               dim_size,
                               BLOCK_SIZE=1024)
     return output
+
+
+@triton.jit
+def _update_kt_cache_ctx_kernel(k_ptr, cache_ptr, block_offsets_ptr,
+                                cum_seq_lens_ptr, cum_kt_seq_lens_ptr,
+                                token_to_batch_map_ptr, num_kv_heads, dim_size,
+                                kt_page_size, tokens_per_block,
+                                max_kt_blocks_per_seq,
+                                BLOCK_SIZE: tl.constexpr):
+    # get program id
+    kt_token_idx = tl.program_id(0)
+
+    # get params
+    batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
+    kv_start_idx = tl.load(cum_seq_lens_ptr + batch_idx)
+    kv_end_idx = tl.load(cum_seq_lens_ptr + batch_idx + 1)
+    kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
+    local_kt_token_idx = kt_token_idx - kt_start_idx
+    global_kv_token_idx = kv_start_idx + local_kt_token_idx * kt_page_size
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size
+    k_base = k_ptr + global_kv_token_idx * hidden_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    token_idx_in_block = local_kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              token_idx_in_block) * hidden_size * 2
+
+    # compute min/max and store kt
+    for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        head_idx = hidden_indices // dim_size
+        dim_idx = hidden_indices % dim_size
+        dim_mask = hidden_indices < hidden_size
+
+        # get k_min and k_max
+        k_min = tl.full((BLOCK_SIZE, ), float('inf'), dtype=tl.float32)
+        k_max = tl.full((BLOCK_SIZE, ), float('-inf'), dtype=tl.float32)
+        for i in range(kt_page_size):
+            if global_kv_token_idx + i < kv_end_idx:
+                k = tl.load(k_base + i * hidden_size + hidden_indices,
+                            mask=dim_mask,
+                            other=0.0)
+                k_min = tl.minimum(k_min, k)
+                k_max = tl.maximum(k_max, k)
+        k_min = k_min.to(cache_ptr.dtype.element_ty)
+        k_max = k_max.to(cache_ptr.dtype.element_ty)
+
+        # store k_min and k_max to cache
+        k_min_offset = cache_base + head_idx * dim_size * 2 + dim_idx
+        k_max_offset = k_min_offset + dim_size
+        tl.store(k_min_offset, k_min, mask=dim_mask)
+        tl.store(k_max_offset, k_max, mask=dim_mask)
+
+
+@triton.jit
+def _update_kt_cache_gen_kernel(k_ptr, cache_ptr, block_offsets_ptr,
+                                seq_lens_ptr, num_kv_heads, dim_size,
+                                kt_page_size, tokens_per_block,
+                                max_kt_blocks_per_seq,
+                                BLOCK_SIZE: tl.constexpr):
+    # get program id
+    batch_idx = tl.program_id(0)
+    head_idx = tl.program_id(1)
+
+    # get params
+    past_key_value_length = tl.load(seq_lens_ptr + batch_idx) - 1
+    kt_token_idx = past_key_value_length // kt_page_size
+    kt_token_idx_in_page = past_key_value_length % kt_page_size
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size
+    k_base = k_ptr + batch_idx * hidden_size + head_idx * dim_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    kt_token_idx_in_block = kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              kt_token_idx_in_block) * hidden_size * 2
+    cache_base += head_idx * dim_size * 2
+
+    # update kt cache
+    for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        dim_mask = hidden_indices < dim_size
+
+        # load k
+        k = tl.load(k_base + hidden_indices, mask=dim_mask, other=0.0)
+
+        # load kt cache
+        kt_mask = dim_mask & (kt_token_idx_in_page > 0)
+        k_min = tl.load(cache_base + hidden_indices,
+                        mask=kt_mask,
+                        other=float('inf'))
+        k_max = tl.load(cache_base + hidden_indices + dim_size,
+                        mask=kt_mask,
+                        other=float('-inf'))
+        k_min = tl.minimum(k_min, k)
+        k_max = tl.maximum(k_max, k)
+        k_min = k_min.to(cache_ptr.dtype.element_ty)
+        k_max = k_max.to(cache_ptr.dtype.element_ty)
+
+        # store k_min and k_max to cache
+        tl.store(cache_base + hidden_indices, k_min, mask=dim_mask)
+        tl.store(cache_base + hidden_indices + dim_size, k_max, mask=dim_mask)
+
+
+@triton.jit
+def _load_kt_cache_kernel(kt_states_ptr, cache_ptr, block_offsets_ptr,
+                          cum_kt_seq_lens_ptr, token_to_batch_map_ptr,
+                          num_kv_heads, dim_size, tokens_per_block,
+                          max_kt_blocks_per_seq, BLOCK_SIZE: tl.constexpr):
+    # get program id
+    kt_token_idx = tl.program_id(0)
+
+    # get params
+    batch_idx = tl.load(token_to_batch_map_ptr + kt_token_idx)
+    kt_start_idx = tl.load(cum_kt_seq_lens_ptr + batch_idx)
+    local_kt_token_idx = kt_token_idx - kt_start_idx
+
+    # get offsets
+    hidden_size = num_kv_heads * dim_size * 2
+    kt_states_base = kt_states_ptr + kt_token_idx * hidden_size
+    block_offset = batch_idx * max_kt_blocks_per_seq + local_kt_token_idx // tokens_per_block
+    block_idx = tl.load(block_offsets_ptr + block_offset)
+    token_idx_in_block = local_kt_token_idx % tokens_per_block
+    cache_base = cache_ptr + (block_idx * tokens_per_block +
+                              token_idx_in_block) * hidden_size
+
+    # load kt cache
+    for hidden_start in tl.range(0, hidden_size, BLOCK_SIZE):
+        hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
+        mask = hidden_indices < hidden_size
+        # load kt cache
+        kt = tl.load(cache_base + hidden_indices, mask=mask, other=0.0)
+        # store kt to kt_states
+        tl.store(kt_states_base + hidden_indices, kt, mask=mask)
+
+
+def triton_update_kt_cache(k,
+                           kt_cache_tensor,
+                           kt_cache_block_offsets,
+                           seq_lens,
+                           kt_page_size,
+                           tokens_per_block,
+                           max_kt_blocks_per_seq,
+                           update=True):
+    # inputs:
+    # k: (total_seq_len, num_kv_heads, head_dim)
+    # kt_cache_tensor: (num_blocks, tokens_per_block, num_kv_heads, 2 * head_dim)
+    # kt_cache_block_offsets: (max_batch_size, max_kt_blocks_per_seq)
+    # seq_lens: (batch_size)
+    # kt_page_size: int, number of kv tokens summarized per kt token
+    # update: bool, False for the context (prefill) phase, True for generation
+
+    # outputs:
+    # kt_states: (total_kt_tokens, num_kv_heads, 2 * head_dim); generation only,
+    # the context phase only writes the cache and returns None
+    # params
+    batch_size = seq_lens.size(0)
+    num_kv_heads = k.size(1)
+    head_dim = k.size(2)
+    tokens_per_block = kt_cache_tensor.size(1)
+    num_kt_tokens = (seq_lens + kt_page_size - 1) // kt_page_size
+
+    # context
+    if not update:
+        total_num_kt_tokens = num_kt_tokens.sum().item()
+        cum_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            seq_lens.to(torch.long)
+        ]),
+                                    dim=0)
+        cum_kt_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            num_kt_tokens.to(torch.long)
+        ]),
+                                       dim=0)
+
+        token_to_batch_map = torch.repeat_interleave(
+            torch.arange(batch_size,
+                         device='cuda'), repeats=num_kt_tokens).to(torch.long)
+        grid = (total_num_kt_tokens, )
+        _update_kt_cache_ctx_kernel[grid](k,
+                                          kt_cache_tensor,
+                                          kt_cache_block_offsets,
+                                          cum_seq_lens,
+                                          cum_kt_seq_lens,
+                                          token_to_batch_map,
+                                          num_kv_heads,
+                                          head_dim,
+                                          kt_page_size,
+                                          tokens_per_block,
+                                          max_kt_blocks_per_seq,
+                                          BLOCK_SIZE=1024)
+        return
+    else:
+        # generation
+        # update kt cache
+        grid = (batch_size, num_kv_heads)
+        _update_kt_cache_gen_kernel[grid](k,
+                                          kt_cache_tensor,
+                                          kt_cache_block_offsets,
+                                          seq_lens,
+                                          num_kv_heads,
+                                          head_dim,
+                                          kt_page_size,
+                                          tokens_per_block,
+                                          max_kt_blocks_per_seq,
+                                          BLOCK_SIZE=1024)
+
+        # load kt cache
+        total_num_kt_tokens = num_kt_tokens.sum().item()
+        kt_states = torch.empty(
+            (total_num_kt_tokens, num_kv_heads, 2 * head_dim),
+            device='cuda',
+            dtype=k.dtype)
+        token_to_batch_map = torch.repeat_interleave(
+            torch.arange(batch_size,
+                         device='cuda'), repeats=num_kt_tokens).to(torch.long)
+        cum_kt_seq_lens = torch.cumsum(torch.cat([
+            torch.zeros(1, device='cuda', dtype=torch.long),
+            num_kt_tokens.to(torch.long)
+        ]),
+                                       dim=0)
+        grid = (total_num_kt_tokens, )
+        _load_kt_cache_kernel[grid](kt_states,
+                                    kt_cache_tensor,
+                                    kt_cache_block_offsets,
+                                    cum_kt_seq_lens,
+                                    token_to_batch_map,
+                                    num_kv_heads,
+                                    head_dim,
+                                    tokens_per_block,
+                                    max_kt_blocks_per_seq,
+                                    BLOCK_SIZE=1024)
+
+        return kt_states
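
For reference, here is a minimal sketch of how the new triton_update_kt_cache wrapper might be driven for one prefill pass and one decode step. Note the cache layout the kernels assume: each kt token stores, per KV head, head_dim running minima of the keys in its page followed by head_dim running maxima. All sizes, dtypes, and the toy block table below are illustrative assumptions and are not part of this commit.

import torch

# assumed toy sizes (not from the commit)
batch_size, num_kv_heads, head_dim = 2, 2, 64
kt_page_size, tokens_per_block, max_kt_blocks_per_seq = 4, 8, 4
num_blocks = batch_size * max_kt_blocks_per_seq

# kt cache: one (k_min, k_max) pair per kt token and head, hence 2 * head_dim
kt_cache = torch.zeros(num_blocks, tokens_per_block, num_kv_heads,
                       2 * head_dim, device='cuda', dtype=torch.float16)
# toy block table: sequence b owns blocks b * max_kt_blocks_per_seq, ...
block_offsets = torch.arange(num_blocks, device='cuda', dtype=torch.int32)
block_offsets = block_offsets.view(batch_size, max_kt_blocks_per_seq)

# context (prefill): k packs every token of every sequence
seq_lens = torch.tensor([5, 9], device='cuda', dtype=torch.int32)
k_ctx = torch.randn(int(seq_lens.sum()), num_kv_heads, head_dim,
                    device='cuda', dtype=torch.float16)
triton_update_kt_cache(k_ctx, kt_cache, block_offsets, seq_lens,
                       kt_page_size, tokens_per_block,
                       max_kt_blocks_per_seq, update=False)  # writes the cache only

# generation (decode): k holds exactly one new token per sequence
seq_lens = seq_lens + 1
k_gen = torch.randn(batch_size, num_kv_heads, head_dim,
                    device='cuda', dtype=torch.float16)
kt_states = triton_update_kt_cache(k_gen, kt_cache, block_offsets, seq_lens,
                                   kt_page_size, tokens_per_block,
                                   max_kt_blocks_per_seq, update=True)
# kt_states: (sum(ceil(seq_len / kt_page_size)), num_kv_heads, 2 * head_dim) = (5, 2, 128)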
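To sanity-check the kernels, the per-page min/max they compute can be reproduced in plain PyTorch. The helper below is a hypothetical reference, not part of this commit; applied to all keys seen so far for each sequence, it should reproduce the summaries the kernels write into the cache and that the load kernel packs into kt_states, up to dtype rounding.

import torch

def kt_reference(k, seq_lens, kt_page_size):
    # k: (total_seq_len, num_kv_heads, head_dim), packed per sequence
    # returns: (total_kt_tokens, num_kv_heads, 2 * head_dim), min then max per head
    out, start = [], 0
    for seq_len in seq_lens.tolist():
        keys = k[start:start + seq_len]
        start += seq_len
        for page_start in range(0, seq_len, kt_page_size):
            page = keys[page_start:page_start + kt_page_size]
            out.append(torch.cat([page.min(dim=0).values,
                                  page.max(dim=0).values], dim=-1))
    return torch.stack(out)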