@@ -286,18 +286,11 @@ def __post_init__(self):
286286
287287 capture_graph = torch .cuda .is_current_stream_capturing ()
288288
289- def get_empty (tensor_shape : list [int ], dtype : torch .dtype ,
290- cache_name : str ) -> torch .Tensor :
291- if self .cuda_graph_buffers is None :
292- return torch .zeros (tensor_shape , device = 'cuda' , dtype = dtype )
293- return self .cuda_graph_buffers .get_buffer (tensor_shape , dtype ,
294- cache_name , capture_graph )
295-
296289 self .indexer_k_cache_block_offsets = get_empty (
297290 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
298291 cache_name = "indexer_k_cache_block_offsets" ,
299292 dtype = torch .int32 ,
300- )
293+ capture_graph = capture_graph )
301294 self .host_indexer_k_cache_block_offsets = torch .zeros_like (
302295 self .indexer_k_cache_block_offsets ,
303296 device = 'cpu' ,
@@ -310,17 +303,16 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
310303 (self .max_num_requests + 1 , ),
311304 cache_name = "ctx_cached_token_indptr" ,
312305 dtype = torch .int64 ,
313- )
306+ capture_graph = capture_graph )
314307 self .host_ctx_cached_token_indptr = torch .zeros_like (
315308 self .ctx_cached_token_indptr ,
316309 device = 'cpu' ,
317310 pin_memory = True ,
318311 )
319- self .ctx_kv_indptr = get_empty (
320- (self .max_num_requests + 1 , ),
321- cache_name = "ctx_kv_indptr" ,
322- dtype = torch .int64 ,
323- )
312+ self .ctx_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
313+ cache_name = "ctx_kv_indptr" ,
314+ dtype = torch .int64 ,
315+ capture_graph = capture_graph )
324316 self .host_ctx_kv_indptr = torch .zeros_like (
325317 self .ctx_kv_indptr ,
326318 device = 'cpu' ,
@@ -331,71 +323,65 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
331323 (self .max_num_requests + 1 , ),
332324 cache_name = "gen_cached_token_indptr" ,
333325 dtype = torch .int64 ,
334- )
326+ capture_graph = capture_graph )
335327 self .host_gen_cached_token_indptr = torch .zeros_like (
336328 self .gen_cached_token_indptr ,
337329 device = 'cpu' ,
338330 pin_memory = True ,
339331 )
340- self .gen_kv_indptr = get_empty (
341- (self .max_num_requests + 1 , ),
342- cache_name = "gen_kv_indptr" ,
343- dtype = torch .int64 ,
344- )
332+ self .gen_kv_indptr = get_empty ((self .max_num_requests + 1 , ),
333+ cache_name = "gen_kv_indptr" ,
334+ dtype = torch .int64 ,
335+ capture_graph = capture_graph )
345336 self .host_gen_kv_indptr = torch .zeros_like (
346337 self .gen_kv_indptr ,
347338 device = 'cpu' ,
348339 pin_memory = True ,
349340 )
350341 # Indexer metadata
351342 # Separate slot mappings for non-interleaved layout (flat byte indices)
352- self .slot_mapping_fp8 = get_empty (
353- (self .max_num_tokens , ),
354- cache_name = "slot_mapping_fp8" ,
355- dtype = torch .int64 ,
356- )
343+ self .slot_mapping_fp8 = get_empty ((self .max_num_tokens , ),
344+ cache_name = "slot_mapping_fp8" ,
345+ dtype = torch .int64 ,
346+ capture_graph = capture_graph )
357347 self .host_slot_mapping_fp8 = torch .zeros_like (
358348 self .slot_mapping_fp8 ,
359349 device = 'cpu' ,
360350 pin_memory = True ,
361351 )
362- self .slot_mapping_scale = get_empty (
363- (self .max_num_tokens , ),
364- cache_name = "slot_mapping_scale" ,
365- dtype = torch .int64 ,
366- )
352+ self .slot_mapping_scale = get_empty ((self .max_num_tokens , ),
353+ cache_name = "slot_mapping_scale" ,
354+ dtype = torch .int64 ,
355+ capture_graph = capture_graph )
367356 self .host_slot_mapping_scale = torch .zeros_like (
368357 self .slot_mapping_scale ,
369358 device = 'cpu' ,
370359 pin_memory = True ,
371360 )
372361 # Per-token request index buffer for topk_indices conversion
373- self .req_idx_per_token = get_empty (
374- (self .max_num_tokens , ),
375- cache_name = "req_idx_per_token" ,
376- dtype = torch .int32 ,
377- )
362+ self .req_idx_per_token = get_empty ((self .max_num_tokens , ),
363+ cache_name = "req_idx_per_token" ,
364+ dtype = torch .int32 ,
365+ capture_graph = capture_graph )
378366 # Block table for topk_indices conversion (shared for context and generation)
379367 self .block_table = get_empty (
380368 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
381369 cache_name = "block_table" ,
382370 dtype = torch .int32 ,
383- )
371+ capture_graph = capture_graph )
384372 self .scheduler_metadata_buffer = get_empty (
385373 (self .num_sms + 1 , 2 ),
386374 cache_name = "scheduler_metadata_buffer" ,
387375 dtype = torch .int32 ,
388- )
389- self .cu_seqlen_ks = get_empty (
390- (self .max_num_tokens , ),
391- cache_name = "cu_seqlen_ks" ,
392- dtype = torch .int32 ,
393- )
394- self .cu_seqlen_ke = get_empty (
395- (self .max_num_tokens , ),
396- cache_name = "cu_seqlen_ke" ,
397- dtype = torch .int32 ,
398- )
376+ capture_graph = capture_graph )
377+ self .cu_seqlen_ks = get_empty ((self .max_num_tokens , ),
378+ cache_name = "cu_seqlen_ks" ,
379+ dtype = torch .int32 ,
380+ capture_graph = capture_graph )
381+ self .cu_seqlen_ke = get_empty ((self .max_num_tokens , ),
382+ cache_name = "cu_seqlen_ke" ,
383+ dtype = torch .int32 ,
384+ capture_graph = capture_graph )
399385
400386 def prepare (self ):
401387 super ().prepare ()
0 commit comments