@@ -309,7 +309,8 @@ def __post_init__(self):
309309 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
310310 cache_name = "indexer_k_cache_block_offsets" ,
311311 dtype = torch .int32 ,
312- capture_graph = capture_graph )
312+ capture_graph = capture_graph ,
313+ )
313314 self .host_indexer_k_cache_block_offsets = torch .zeros_like (
314315 self .indexer_k_cache_block_offsets ,
315316 device = 'cpu' ,
@@ -319,96 +320,117 @@ def __post_init__(self):
319320 # For mla_rope_append_paged_kv_assign_q
320321 if not self .enable_context_mla_with_cached_kv :
321322 self .ctx_cached_token_indptr = self .get_empty (
322- self .cuda_graph_buffers , (self .max_num_requests + 1 , ),
323+ self .cuda_graph_buffers ,
324+ (self .max_num_requests + 1 , ),
323325 cache_name = "ctx_cached_token_indptr" ,
324326 dtype = torch .int64 ,
325- capture_graph = capture_graph )
327+ capture_graph = capture_graph ,
328+ )
326329 self .host_ctx_cached_token_indptr = torch .zeros_like (
327330 self .ctx_cached_token_indptr ,
328331 device = 'cpu' ,
329332 pin_memory = True ,
330333 )
331- self .ctx_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
332- (self .max_num_requests + 1 , ),
333- cache_name = "ctx_kv_indptr" ,
334- dtype = torch .int64 ,
335- capture_graph = capture_graph )
334+ self .ctx_kv_indptr = self .get_empty (
335+ self .cuda_graph_buffers ,
336+ (self .max_num_requests + 1 , ),
337+ cache_name = "ctx_kv_indptr" ,
338+ dtype = torch .int64 ,
339+ capture_graph = capture_graph ,
340+ )
336341 self .host_ctx_kv_indptr = torch .zeros_like (
337342 self .ctx_kv_indptr ,
338343 device = 'cpu' ,
339344 pin_memory = True ,
340345 )
341346 # New generation buffers for dsa
342347 self .gen_cached_token_indptr = self .get_empty (
343- self .cuda_graph_buffers , (self .max_num_requests + 1 , ),
348+ self .cuda_graph_buffers ,
349+ (self .max_num_requests + 1 , ),
344350 cache_name = "gen_cached_token_indptr" ,
345351 dtype = torch .int64 ,
346- capture_graph = capture_graph )
352+ capture_graph = capture_graph ,
353+ )
347354 self .host_gen_cached_token_indptr = torch .zeros_like (
348355 self .gen_cached_token_indptr ,
349356 device = 'cpu' ,
350357 pin_memory = True ,
351358 )
352- self .gen_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
353- (self .max_num_requests + 1 , ),
354- cache_name = "gen_kv_indptr" ,
355- dtype = torch .int64 ,
356- capture_graph = capture_graph )
359+ self .gen_kv_indptr = self .get_empty (
360+ self .cuda_graph_buffers ,
361+ (self .max_num_requests + 1 , ),
362+ cache_name = "gen_kv_indptr" ,
363+ dtype = torch .int64 ,
364+ capture_graph = capture_graph ,
365+ )
357366 self .host_gen_kv_indptr = torch .zeros_like (
358367 self .gen_kv_indptr ,
359368 device = 'cpu' ,
360369 pin_memory = True ,
361370 )
362371 # Indexer metadata
363372 # Separate slot mappings for non-interleaved layout (flat byte indices)
364- self .slot_mapping_fp8 = self .get_empty (self .cuda_graph_buffers ,
365- (self .max_num_tokens , ),
366- cache_name = "slot_mapping_fp8" ,
367- dtype = torch .int64 ,
368- capture_graph = capture_graph )
373+ self .slot_mapping_fp8 = self .get_empty (
374+ self .cuda_graph_buffers ,
375+ (self .max_num_tokens , ),
376+ cache_name = "slot_mapping_fp8" ,
377+ dtype = torch .int64 ,
378+ capture_graph = capture_graph ,
379+ )
369380 self .host_slot_mapping_fp8 = torch .zeros_like (
370381 self .slot_mapping_fp8 ,
371382 device = 'cpu' ,
372383 pin_memory = True ,
373384 )
374385 self .slot_mapping_scale = self .get_empty (
375- self .cuda_graph_buffers , (self .max_num_tokens , ),
386+ self .cuda_graph_buffers ,
387+ (self .max_num_tokens , ),
376388 cache_name = "slot_mapping_scale" ,
377389 dtype = torch .int64 ,
378- capture_graph = capture_graph )
390+ capture_graph = capture_graph ,
391+ )
379392 self .host_slot_mapping_scale = torch .zeros_like (
380393 self .slot_mapping_scale ,
381394 device = 'cpu' ,
382395 pin_memory = True ,
383396 )
384397 # Per-token request index buffer for topk_indices conversion
385- self .req_idx_per_token = self .get_empty (self .cuda_graph_buffers ,
386- (self .max_num_tokens , ),
387- cache_name = "req_idx_per_token" ,
388- dtype = torch .int32 ,
389- capture_graph = capture_graph )
398+ self .req_idx_per_token = self .get_empty (
399+ self .cuda_graph_buffers ,
400+ (self .max_num_tokens , ),
401+ cache_name = "req_idx_per_token" ,
402+ dtype = torch .int32 ,
403+ capture_graph = capture_graph ,
404+ )
390405 # Block table for topk_indices conversion (shared for context and generation)
391406 self .block_table = self .get_empty (
392407 self .cuda_graph_buffers ,
393408 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
394409 cache_name = "block_table" ,
395410 dtype = torch .int32 ,
396- capture_graph = capture_graph )
411+ capture_graph = capture_graph ,
412+ )
397413 self .scheduler_metadata_buffer = self .get_empty (
398- self .cuda_graph_buffers , (self .num_sms + 1 , 2 ),
414+ self .cuda_graph_buffers ,
415+ (self .num_sms + 1 , 2 ),
399416 cache_name = "scheduler_metadata_buffer" ,
400417 dtype = torch .int32 ,
401- capture_graph = capture_graph )
402- self .cu_seqlen_ks = self .get_empty (self .cuda_graph_buffers ,
403- (self .max_num_tokens , ),
404- cache_name = "cu_seqlen_ks" ,
405- dtype = torch .int32 ,
406- capture_graph = capture_graph )
407- self .cu_seqlen_ke = self .get_empty (self .cuda_graph_buffers ,
408- (self .max_num_tokens , ),
409- cache_name = "cu_seqlen_ke" ,
410- dtype = torch .int32 ,
411- capture_graph = capture_graph )
418+ capture_graph = capture_graph ,
419+ )
420+ self .cu_seqlen_ks = self .get_empty (
421+ self .cuda_graph_buffers ,
422+ (self .max_num_tokens , ),
423+ cache_name = "cu_seqlen_ks" ,
424+ dtype = torch .int32 ,
425+ capture_graph = capture_graph ,
426+ )
427+ self .cu_seqlen_ke = self .get_empty (
428+ self .cuda_graph_buffers ,
429+ (self .max_num_tokens , ),
430+ cache_name = "cu_seqlen_ke" ,
431+ dtype = torch .int32 ,
432+ capture_graph = capture_graph ,
433+ )
412434
413435 def prepare (self ):
414436 super ().prepare ()
0 commit comments