@@ -75,10 +75,24 @@ def __post_init__(self):
7575 dtype = torch .int32 )
7676
7777 # Context length of RocketKV key for each valid sequence
78- self .k_context_lens = torch . empty (
79- self .max_num_sequences ,
80- device = 'cpu' ,
78+ self .k_context_lens_cuda = self . get_empty (
79+ self .cuda_graph_buffers ,
80+ ( self . max_num_sequences , ) ,
8181 dtype = torch .int32 ,
82+ cache_name = "k_context_lens_cuda" ,
83+ capture_graph = capture_graph ,
84+ )
85+ self .k_context_lens = torch .zeros_like (self .k_context_lens_cuda ,
86+ device = 'cpu' ,
87+ dtype = torch .int32 )
88+
89+ # Start index of RocketKV key for each valid sequence
90+ self .k_context_start_cuda = self .get_empty (
91+ None ,
92+ (self .max_num_sequences , ),
93+ dtype = torch .int32 ,
94+ cache_name = "k_context_start_cuda" ,
95+ capture_graph = capture_graph ,
8296 )
8397
8498 # Cumulative context lengths for each sequence
@@ -231,6 +245,8 @@ def prepare(self):
231245 # Only consider sequences that are long enough for sparse kv indices prediction in context phase
232246 self .k_context_lens [:valid_batch_size ] = self .prompt_lens_cpu [
233247 valid_seq_indices ] - self .window_size
248+ self .k_context_lens_cuda [:valid_batch_size ].copy_ (
249+ self .k_context_lens [:valid_batch_size ], non_blocking = True )
234250
235251 sparse_counts_ctx = torch .zeros (self .num_contexts ,
236252 dtype = torch .int32 ,
@@ -399,12 +415,32 @@ def sparse_kv_predict(
399415 padding = self .kernel_size // 2 ,
400416 stride = 1 )
401417
402- selected_prefix_indices = scores .topk (
403- self .prompt_budget - self .window_size ,
404- dim = - 1 ).indices .sort ().values .to (torch .int32 )
418+ # Use indexer topk prefill to select topk prefix indices
419+ total_tasks = metadata .valid_batch_size * self .num_kv_heads
420+
421+ selected_prefix_indices = torch .empty (
422+ (total_tasks , self .prompt_budget - self .window_size ),
423+ device = qkv_input .device ,
424+ dtype = torch .int32 )
425+
426+ scores = scores .view (total_tasks , - 1 )
427+
428+ row_starts = metadata .k_context_start_cuda [:metadata .
429+ valid_batch_size ].repeat_interleave (
430+ self .num_kv_heads )
431+ row_ends = metadata .k_context_lens_cuda [:metadata .
432+ valid_batch_size ].repeat_interleave (
433+ self .num_kv_heads )
434+ torch .ops .trtllm .indexer_topk_prefill (
435+ scores , row_starts , row_ends , selected_prefix_indices ,
436+ self .prompt_budget - self .window_size )
437+
438+ # Sort selected prefix indices to keep topk indices in ascending order
439+ selected_prefix_indices = torch .sort (selected_prefix_indices ,
440+ dim = - 1 ).values
405441 else :
406442 selected_prefix_indices = torch .empty (
407- (0 , self .num_kv_heads , self . prompt_budget - self .window_size ),
443+ (0 , self .prompt_budget - self .window_size ),
408444 device = qkv_input .device ,
409445 dtype = torch .int32 )
410446
@@ -416,7 +452,7 @@ def sparse_kv_predict(
416452 selected_prefix_indices , metadata .prompt_lens_cuda ,
417453 metadata .valid_seq_indices_cuda , sparse_kv_offsets ,
418454 metadata .num_contexts , metadata .total_sparse_ctx_indices ,
419- self .window_size , self .prompt_budget )
455+ self .window_size , self .prompt_budget , self . num_kv_heads )
420456
421457 # Update KT cache
422458 kt_cache_tensor = metadata .kv_cache_manager .get_kt_buffers (
0 commit comments