@@ -286,7 +286,7 @@ def __post_init__(self):
286286
287287 capture_graph = torch .cuda .is_current_stream_capturing ()
288288
289- self .indexer_k_cache_block_offsets = get_empty (
289+ self .indexer_k_cache_block_offsets = self . get_empty (
290290 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
291291 cache_name = "indexer_k_cache_block_offsets" ,
292292 dtype = torch .int32 ,
@@ -299,7 +299,7 @@ def __post_init__(self):
299299
300300 # For mla_rope_append_paged_kv_assign_q
301301 if not self .enable_context_mla_with_cached_kv :
302- self .ctx_cached_token_indptr = get_empty (
302+ self .ctx_cached_token_indptr = self . get_empty (
303303 (self .max_num_requests + 1 , ),
304304 cache_name = "ctx_cached_token_indptr" ,
305305 dtype = torch .int64 ,
@@ -309,17 +309,17 @@ def __post_init__(self):
309309 device = 'cpu' ,
310310 pin_memory = True ,
311311 )
312- self .ctx_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
313- cache_name = "ctx_kv_indptr" ,
314- dtype = torch .int64 ,
315- capture_graph = capture_graph )
312+ self .ctx_kv_indptr = self . get_empty ((self .max_num_requests + 1 , ),
313+ cache_name = "ctx_kv_indptr" ,
314+ dtype = torch .int64 ,
315+ capture_graph = capture_graph )
316316 self .host_ctx_kv_indptr = torch .zeros_like (
317317 self .ctx_kv_indptr ,
318318 device = 'cpu' ,
319319 pin_memory = True ,
320320 )
321321 # New generation buffers for dsa
322- self .gen_cached_token_indptr = get_empty (
322+ self .gen_cached_token_indptr = self . get_empty (
323323 (self .max_num_requests + 1 , ),
324324 cache_name = "gen_cached_token_indptr" ,
325325 dtype = torch .int64 ,
@@ -329,59 +329,60 @@ def __post_init__(self):
329329 device = 'cpu' ,
330330 pin_memory = True ,
331331 )
332- self .gen_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
333- cache_name = "gen_kv_indptr" ,
334- dtype = torch .int64 ,
335- capture_graph = capture_graph )
332+ self .gen_kv_indptr = self . get_empty ((self .max_num_requests + 1 , ),
333+ cache_name = "gen_kv_indptr" ,
334+ dtype = torch .int64 ,
335+ capture_graph = capture_graph )
336336 self .host_gen_kv_indptr = torch .zeros_like (
337337 self .gen_kv_indptr ,
338338 device = 'cpu' ,
339339 pin_memory = True ,
340340 )
341341 # Indexer metadata
342342 # Separate slot mappings for non-interleaved layout (flat byte indices)
343- self .slot_mapping_fp8 = get_empty ((self .max_num_tokens , ),
344- cache_name = "slot_mapping_fp8" ,
345- dtype = torch .int64 ,
346- capture_graph = capture_graph )
343+ self .slot_mapping_fp8 = self . get_empty ((self .max_num_tokens , ),
344+ cache_name = "slot_mapping_fp8" ,
345+ dtype = torch .int64 ,
346+ capture_graph = capture_graph )
347347 self .host_slot_mapping_fp8 = torch .zeros_like (
348348 self .slot_mapping_fp8 ,
349349 device = 'cpu' ,
350350 pin_memory = True ,
351351 )
352- self .slot_mapping_scale = get_empty ((self .max_num_tokens , ),
353- cache_name = "slot_mapping_scale" ,
354- dtype = torch .int64 ,
355- capture_graph = capture_graph )
352+ self .slot_mapping_scale = self .get_empty (
353+ (self .max_num_tokens , ),
354+ cache_name = "slot_mapping_scale" ,
355+ dtype = torch .int64 ,
356+ capture_graph = capture_graph )
356357 self .host_slot_mapping_scale = torch .zeros_like (
357358 self .slot_mapping_scale ,
358359 device = 'cpu' ,
359360 pin_memory = True ,
360361 )
361362 # Per-token request index buffer for topk_indices conversion
362- self .req_idx_per_token = get_empty ((self .max_num_tokens , ),
363- cache_name = "req_idx_per_token" ,
364- dtype = torch .int32 ,
365- capture_graph = capture_graph )
363+ self .req_idx_per_token = self . get_empty ((self .max_num_tokens , ),
364+ cache_name = "req_idx_per_token" ,
365+ dtype = torch .int32 ,
366+ capture_graph = capture_graph )
366367 # Block table for topk_indices conversion (shared for context and generation)
367- self .block_table = get_empty (
368+ self .block_table = self . get_empty (
368369 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
369370 cache_name = "block_table" ,
370371 dtype = torch .int32 ,
371372 capture_graph = capture_graph )
372- self .scheduler_metadata_buffer = get_empty (
373+ self .scheduler_metadata_buffer = self . get_empty (
373374 (self .num_sms + 1 , 2 ),
374375 cache_name = "scheduler_metadata_buffer" ,
375376 dtype = torch .int32 ,
376377 capture_graph = capture_graph )
377- self .cu_seqlen_ks = get_empty ((self .max_num_tokens , ),
378- cache_name = "cu_seqlen_ks" ,
379- dtype = torch .int32 ,
380- capture_graph = capture_graph )
381- self .cu_seqlen_ke = get_empty ((self .max_num_tokens , ),
382- cache_name = "cu_seqlen_ke" ,
383- dtype = torch .int32 ,
384- capture_graph = capture_graph )
378+ self .cu_seqlen_ks = self . get_empty ((self .max_num_tokens , ),
379+ cache_name = "cu_seqlen_ks" ,
380+ dtype = torch .int32 ,
381+ capture_graph = capture_graph )
382+ self .cu_seqlen_ke = self . get_empty ((self .max_num_tokens , ),
383+ cache_name = "cu_seqlen_ke" ,
384+ dtype = torch .int32 ,
385+ capture_graph = capture_graph )
385386
386387 def prepare (self ):
387388 super ().prepare ()
0 commit comments