@@ -287,6 +287,7 @@ def __post_init__(self):
287287 capture_graph = torch .cuda .is_current_stream_capturing ()
288288
289289 self .indexer_k_cache_block_offsets = self .get_empty (
290+ self .cuda_graph_buffers ,
290291 [self .max_num_sequences , self .kv_cache_manager .max_blocks_per_seq ],
291292 cache_name = "indexer_k_cache_block_offsets" ,
292293 dtype = torch .int32 ,
@@ -300,7 +301,7 @@ def __post_init__(self):
300301 # For mla_rope_append_paged_kv_assign_q
301302 if not self .enable_context_mla_with_cached_kv :
302303 self .ctx_cached_token_indptr = self .get_empty (
303- (self .max_num_requests + 1 , ),
304+ self . cuda_graph_buffers , (self .max_num_requests + 1 , ),
304305 cache_name = "ctx_cached_token_indptr" ,
305306 dtype = torch .int64 ,
306307 capture_graph = capture_graph )
@@ -309,7 +310,8 @@ def __post_init__(self):
309310 device = 'cpu' ,
310311 pin_memory = True ,
311312 )
312- self .ctx_kv_indptr = self .get_empty ((self .max_num_requests + 1 , ),
313+ self .ctx_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
314+ (self .max_num_requests + 1 , ),
313315 cache_name = "ctx_kv_indptr" ,
314316 dtype = torch .int64 ,
315317 capture_graph = capture_graph )
@@ -320,7 +322,7 @@ def __post_init__(self):
320322 )
321323 # New generation buffers for dsa
322324 self .gen_cached_token_indptr = self .get_empty (
323- (self .max_num_requests + 1 , ),
325+ self . cuda_graph_buffers , (self .max_num_requests + 1 , ),
324326 cache_name = "gen_cached_token_indptr" ,
325327 dtype = torch .int64 ,
326328 capture_graph = capture_graph )
@@ -329,7 +331,8 @@ def __post_init__(self):
329331 device = 'cpu' ,
330332 pin_memory = True ,
331333 )
332- self .gen_kv_indptr = self .get_empty ((self .max_num_requests + 1 , ),
334+ self .gen_kv_indptr = self .get_empty (self .cuda_graph_buffers ,
335+ (self .max_num_requests + 1 , ),
333336 cache_name = "gen_kv_indptr" ,
334337 dtype = torch .int64 ,
335338 capture_graph = capture_graph )
@@ -340,7 +343,8 @@ def __post_init__(self):
340343 )
341344 # Indexer metadata
342345 # Separate slot mappings for non-interleaved layout (flat byte indices)
343- self .slot_mapping_fp8 = self .get_empty ((self .max_num_tokens , ),
346+ self .slot_mapping_fp8 = self .get_empty (self .cuda_graph_buffers ,
347+ (self .max_num_tokens , ),
344348 cache_name = "slot_mapping_fp8" ,
345349 dtype = torch .int64 ,
346350 capture_graph = capture_graph )
@@ -350,7 +354,7 @@ def __post_init__(self):
350354 pin_memory = True ,
351355 )
352356 self .slot_mapping_scale = self .get_empty (
353- (self .max_num_tokens , ),
357+ self . cuda_graph_buffers , (self .max_num_tokens , ),
354358 cache_name = "slot_mapping_scale" ,
355359 dtype = torch .int64 ,
356360 capture_graph = capture_graph )
@@ -360,26 +364,30 @@ def __post_init__(self):
360364 pin_memory = True ,
361365 )
362366 # Per-token request index buffer for topk_indices conversion
363- self .req_idx_per_token = self .get_empty ((self .max_num_tokens , ),
367+ self .req_idx_per_token = self .get_empty (self .cuda_graph_buffers ,
368+ (self .max_num_tokens , ),
364369 cache_name = "req_idx_per_token" ,
365370 dtype = torch .int32 ,
366371 capture_graph = capture_graph )
367372 # Block table for topk_indices conversion (shared for context and generation)
368373 self .block_table = self .get_empty (
374+ self .cuda_graph_buffers ,
369375 (self .max_num_requests , self .kv_cache_manager .max_blocks_per_seq ),
370376 cache_name = "block_table" ,
371377 dtype = torch .int32 ,
372378 capture_graph = capture_graph )
373379 self .scheduler_metadata_buffer = self .get_empty (
374- (self .num_sms + 1 , 2 ),
380+ self . cuda_graph_buffers , (self .num_sms + 1 , 2 ),
375381 cache_name = "scheduler_metadata_buffer" ,
376382 dtype = torch .int32 ,
377383 capture_graph = capture_graph )
378- self .cu_seqlen_ks = self .get_empty ((self .max_num_tokens , ),
384+ self .cu_seqlen_ks = self .get_empty (self .cuda_graph_buffers ,
385+ (self .max_num_tokens , ),
379386 cache_name = "cu_seqlen_ks" ,
380387 dtype = torch .int32 ,
381388 capture_graph = capture_graph )
382- self .cu_seqlen_ke = self .get_empty ((self .max_num_tokens , ),
389+ self .cu_seqlen_ke = self .get_empty (self .cuda_graph_buffers ,
390+ (self .max_num_tokens , ),
383391 cache_name = "cu_seqlen_ke" ,
384392 dtype = torch .int32 ,
385393 capture_graph = capture_graph )
0 commit comments