@@ -61,7 +61,7 @@ class CUDAGraphRunnerConfig:
6161 max_beam_width : int
6262 max_num_tokens : int
6363 spec_config : Optional [DecodingBaseConfig ]
64- cuda_graph_mem_pool : Any
64+ cuda_graph_mem_pool : torch . cuda . MemPool
6565 use_mrope : bool
6666 original_max_draft_len : int
6767 original_max_total_draft_tokens : int
@@ -98,7 +98,9 @@ def __init__(self, config: CUDAGraphRunnerConfig):
9898 self .graph_outputs : Dict [Tuple [int , int , int ],
9999 Callable [[], Optional [torch .Tensor ]]] = {}
100100 self .graph_metadata : Dict [Tuple [int , int , int ], Dict [str , Any ]] = {}
101- self .memory_pool = config .cuda_graph_mem_pool
101+ self .memory_pool = config .cuda_graph_mem_pool if config .cuda_graph_mem_pool else torch .cuda .MemPool (
102+ )
103+ self .memory_pool_handle = self .memory_pool .id
102104 self .padding_dummy_request : Optional ["Request" ] = None
103105
104106 self .shared_static_tensors : Dict [str , torch .Tensor ] = {}
@@ -293,15 +295,14 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
293295 if postprocess_fn is not None :
294296 postprocess_fn (capture_inputs )
295297
296- with torch .cuda .graph (graph , pool = self .memory_pool ):
298+ with torch .cuda .graph (graph , pool = self .memory_pool_handle ):
297299 output = _setup_spec_decoding_and_forward (
298300 key , forward_fn , capture_inputs )
299301 if postprocess_fn is not None :
300302 postprocess_fn (capture_inputs )
301303
302304 self .graphs [key ] = graph
303305 self .graph_outputs [key ] = make_weak_ref (output )
304- self .memory_pool = graph .pool ()
305306
306307 def replay (self , key : Tuple [int , int , int ],
307308 current_inputs : Dict [str , Any ]) -> Optional [torch .Tensor ]:
@@ -427,6 +428,6 @@ def clear(self):
427428 self .graph_outputs .clear ()
428429 self .graph_metadata .clear ()
429430 self .padding_dummy_request = None
430- del self .memory_pool
431- self .memory_pool = None
431+ del self .memory_pool_handle
432+ self .memory_pool_handle = None
432433 torch .cuda .empty_cache ()
0 commit comments