@@ -305,6 +305,7 @@ def __post_init__(self):
305305 capture_graph = torch .cuda .is_current_stream_capturing ()
306306
307307 self .indexer_k_cache_block_offsets = self .get_empty (
308+ self .cuda_graph_buffers ,
308309 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
309310 cache_name = "indexer_k_cache_block_offsets" ,
310311 dtype = torch .int32 ,
@@ -318,7 +319,7 @@ def __post_init__(self):
318319 # For mla_rope_append_paged_kv_assign_q
319320 if not self .enable_context_mla_with_cached_kv :
320321 self .ctx_cached_token_indptr = self .get_empty (
321- (self .max_num_requests + 1 , ),
322+ self . cuda_graph_buffers , (self .max_num_requests + 1 , ),
322323 cache_name = "ctx_cached_token_indptr" ,
323324 dtype = torch .int64 ,
324325 capture_graph = capture_graph )
@@ -327,7 +328,8 @@ def __post_init__(self):
327328 device = 'cpu' ,
328329 pin_memory = True ,
329330 )
330- self .ctx_kv_indptr = self .get_empty ((self .max_num_requests + 1 , ),
331+ self .ctx_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
332+ (self .max_num_requests + 1 , ),
331333 cache_name = "ctx_kv_indptr" ,
332334 dtype = torch .int64 ,
333335 capture_graph = capture_graph )
@@ -338,7 +340,7 @@ def __post_init__(self):
338340 )
339341 # New generation buffers for dsa
340342 self .gen_cached_token_indptr = self .get_empty (
341- (self .max_num_requests + 1 , ),
343+ self . cuda_graph_buffers , (self .max_num_requests + 1 , ),
342344 cache_name = "gen_cached_token_indptr" ,
343345 dtype = torch .int64 ,
344346 capture_graph = capture_graph )
@@ -347,7 +349,8 @@ def __post_init__(self):
347349 device = 'cpu' ,
348350 pin_memory = True ,
349351 )
350- self .gen_kv_indptr = self .get_empty ((self .max_num_requests + 1 , ),
352+ self .gen_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
353+ (self .max_num_requests + 1 , ),
351354 cache_name = "gen_kv_indptr" ,
352355 dtype = torch .int64 ,
353356 capture_graph = capture_graph )
@@ -358,7 +361,8 @@ def __post_init__(self):
358361 )
359362 # Indexer metadata
360363 # Separate slot mappings for non-interleaved layout (flat byte indices)
361- self .slot_mapping_fp8 = self .get_empty ((self .max_num_tokens , ),
364+ self .slot_mapping_fp8 = self .get_empty (self .cuda_graph_buffers ,
365+ (self .max_num_tokens , ),
362366 cache_name = "slot_mapping_fp8" ,
363367 dtype = torch .int64 ,
364368 capture_graph = capture_graph )
@@ -368,7 +372,7 @@ def __post_init__(self):
368372 pin_memory = True ,
369373 )
370374 self .slot_mapping_scale = self .get_empty (
371- (self .max_num_tokens , ),
375+ self . cuda_graph_buffers , (self .max_num_tokens , ),
372376 cache_name = "slot_mapping_scale" ,
373377 dtype = torch .int64 ,
374378 capture_graph = capture_graph )
@@ -378,26 +382,30 @@ def __post_init__(self):
378382 pin_memory = True ,
379383 )
380384 # Per-token request index buffer for topk_indices conversion
381- self .req_idx_per_token = self .get_empty ((self .max_num_tokens , ),
385+ self .req_idx_per_token = self .get_empty (self .cuda_graph_buffers ,
386+ (self .max_num_tokens , ),
382387 cache_name = "req_idx_per_token" ,
383388 dtype = torch .int32 ,
384389 capture_graph = capture_graph )
385390 # Block table for topk_indices conversion (shared for context and generation)
386391 self .block_table = self .get_empty (
392+ self .cuda_graph_buffers ,
387393 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
388394 cache_name = "block_table" ,
389395 dtype = torch .int32 ,
390396 capture_graph = capture_graph )
391397 self .scheduler_metadata_buffer = self .get_empty (
392- (self .num_sms + 1 , 2 ),
398+ self . cuda_graph_buffers , (self .num_sms + 1 , 2 ),
393399 cache_name = "scheduler_metadata_buffer" ,
394400 dtype = torch .int32 ,
395401 capture_graph = capture_graph )
396- self .cu_seqlen_ks = self .get_empty ((self .max_num_tokens , ),
402+ self .cu_seqlen_ks = self .get_empty (self .cuda_graph_buffers ,
403+ (self .max_num_tokens , ),
397404 cache_name = "cu_seqlen_ks" ,
398405 dtype = torch .int32 ,
399406 capture_graph = capture_graph )
400- self .cu_seqlen_ke = self .get_empty ((self .max_num_tokens , ),
407+ self .cu_seqlen_ke = self .get_empty (self .cuda_graph_buffers ,
408+ (self .max_num_tokens , ),
401409 cache_name = "cu_seqlen_ke" ,
402410 dtype = torch .int32 ,
403411 capture_graph = capture_graph )
0 commit comments