@@ -304,7 +304,7 @@ def __post_init__(self):
304304
305305 capture_graph = torch .cuda .is_current_stream_capturing ()
306306
307- self .indexer_k_cache_block_offsets = get_empty (
307+ self .indexer_k_cache_block_offsets = self . get_empty (
308308 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
309309 cache_name = "indexer_k_cache_block_offsets" ,
310310 dtype = torch .int32 ,
@@ -317,7 +317,7 @@ def __post_init__(self):
317317
318318 # For mla_rope_append_paged_kv_assign_q
319319 if not self .enable_context_mla_with_cached_kv :
320- self .ctx_cached_token_indptr = get_empty (
320+ self .ctx_cached_token_indptr = self . get_empty (
321321 (self .max_num_requests + 1 , ),
322322 cache_name = "ctx_cached_token_indptr" ,
323323 dtype = torch .int64 ,
@@ -327,17 +327,17 @@ def __post_init__(self):
327327 device = 'cpu' ,
328328 pin_memory = True ,
329329 )
330- self .ctx_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
331- cache_name = "ctx_kv_indptr" ,
332- dtype = torch .int64 ,
333- capture_graph = capture_graph )
330+ self .ctx_kv_indptr = self . get_empty ((self .max_num_requests + 1 , ),
331+ cache_name = "ctx_kv_indptr" ,
332+ dtype = torch .int64 ,
333+ capture_graph = capture_graph )
334334 self .host_ctx_kv_indptr = torch .zeros_like (
335335 self .ctx_kv_indptr ,
336336 device = 'cpu' ,
337337 pin_memory = True ,
338338 )
339339 # New generation buffers for dsa
340- self .gen_cached_token_indptr = get_empty (
340+ self .gen_cached_token_indptr = self . get_empty (
341341 (self .max_num_requests + 1 , ),
342342 cache_name = "gen_cached_token_indptr" ,
343343 dtype = torch .int64 ,
@@ -347,59 +347,60 @@ def __post_init__(self):
347347 device = 'cpu' ,
348348 pin_memory = True ,
349349 )
350- self .gen_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
351- cache_name = "gen_kv_indptr" ,
352- dtype = torch .int64 ,
353- capture_graph = capture_graph )
350+ self .gen_kv_indptr = self . get_empty ((self .max_num_requests + 1 , ),
351+ cache_name = "gen_kv_indptr" ,
352+ dtype = torch .int64 ,
353+ capture_graph = capture_graph )
354354 self .host_gen_kv_indptr = torch .zeros_like (
355355 self .gen_kv_indptr ,
356356 device = 'cpu' ,
357357 pin_memory = True ,
358358 )
359359 # Indexer metadata
360360 # Separate slot mappings for non-interleaved layout (flat byte indices)
361- self .slot_mapping_fp8 = get_empty ((self .max_num_tokens , ),
362- cache_name = "slot_mapping_fp8" ,
363- dtype = torch .int64 ,
364- capture_graph = capture_graph )
361+ self .slot_mapping_fp8 = self . get_empty ((self .max_num_tokens , ),
362+ cache_name = "slot_mapping_fp8" ,
363+ dtype = torch .int64 ,
364+ capture_graph = capture_graph )
365365 self .host_slot_mapping_fp8 = torch .zeros_like (
366366 self .slot_mapping_fp8 ,
367367 device = 'cpu' ,
368368 pin_memory = True ,
369369 )
370- self .slot_mapping_scale = get_empty ((self .max_num_tokens , ),
371- cache_name = "slot_mapping_scale" ,
372- dtype = torch .int64 ,
373- capture_graph = capture_graph )
370+ self .slot_mapping_scale = self .get_empty (
371+ (self .max_num_tokens , ),
372+ cache_name = "slot_mapping_scale" ,
373+ dtype = torch .int64 ,
374+ capture_graph = capture_graph )
374375 self .host_slot_mapping_scale = torch .zeros_like (
375376 self .slot_mapping_scale ,
376377 device = 'cpu' ,
377378 pin_memory = True ,
378379 )
379380 # Per-token request index buffer for topk_indices conversion
380- self .req_idx_per_token = get_empty ((self .max_num_tokens , ),
381- cache_name = "req_idx_per_token" ,
382- dtype = torch .int32 ,
383- capture_graph = capture_graph )
381+ self .req_idx_per_token = self . get_empty ((self .max_num_tokens , ),
382+ cache_name = "req_idx_per_token" ,
383+ dtype = torch .int32 ,
384+ capture_graph = capture_graph )
384385 # Block table for topk_indices conversion (shared for context and generation)
385- self .block_table = get_empty (
386+ self .block_table = self . get_empty (
386387 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
387388 cache_name = "block_table" ,
388389 dtype = torch .int32 ,
389390 capture_graph = capture_graph )
390- self .scheduler_metadata_buffer = get_empty (
391+ self .scheduler_metadata_buffer = self . get_empty (
391392 (self .num_sms + 1 , 2 ),
392393 cache_name = "scheduler_metadata_buffer" ,
393394 dtype = torch .int32 ,
394395 capture_graph = capture_graph )
395- self .cu_seqlen_ks = get_empty ((self .max_num_tokens , ),
396- cache_name = "cu_seqlen_ks" ,
397- dtype = torch .int32 ,
398- capture_graph = capture_graph )
399- self .cu_seqlen_ke = get_empty ((self .max_num_tokens , ),
400- cache_name = "cu_seqlen_ke" ,
401- dtype = torch .int32 ,
402- capture_graph = capture_graph )
396+ self .cu_seqlen_ks = self . get_empty ((self .max_num_tokens , ),
397+ cache_name = "cu_seqlen_ks" ,
398+ dtype = torch .int32 ,
399+ capture_graph = capture_graph )
400+ self .cu_seqlen_ke = self . get_empty ((self .max_num_tokens , ),
401+ cache_name = "cu_seqlen_ke" ,
402+ dtype = torch .int32 ,
403+ capture_graph = capture_graph )
403404
404405 def prepare (self ):
405406 super ().prepare ()
0 commit comments