 import torch
 import copy
 import bisect
+from collections import OrderedDict
 from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from .infer_struct import InferStateInfo
 
-
 logger = init_logger(__name__)
 
 
 class CudaGraph:
     # CudaGraph forward pass for the decoding stage.
 
     def __init__(self, max_batch_size=8, max_len_in_batch=8192):
-        self.graph = {}
+        self.graph = OrderedDict()  # for LRU
+
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
         self.args = get_env_start_args()
         self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+        self.max_graph_pool_size = self.args.max_graph_pool_size
 
         # gen cuda graph batch_sizes
         # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -47,12 +49,22 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch
 
-    def need_capture(self, batch_size):
-        find_batch_size = self.find_closest_graph_batch_size(batch_size)
-        if find_batch_size is not None:
-            return find_batch_size not in self.graph
+    def get_graph(self, batch_size):
+        # We assume batch_size has already been mapped through find_closest_graph_batch_size by the caller.
+        # If the graph already exists, pop it and re-insert it so that it
+        # moves to the most recently used position.
+        if batch_size in self.graph:
+            find_graph = self.graph.pop(batch_size)
+            self.graph[batch_size] = find_graph
+            return find_graph
         else:
-            assert False, "dead code"
+            return None
+
+    def evict_oldest_graph(self):
+        # popitem(last=False) removes the least recently used entry in O(1).
+        if self.graph:
+            oldest_batch_size, oldest_graph = self.graph.popitem(last=False)
+            del oldest_graph
+            torch.cuda.empty_cache()
 
     def find_closest_graph_batch_size(self, batch_size):
         index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
@@ -64,6 +76,9 @@ def find_closest_graph_batch_size(self, batch_size):
 
     def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
@@ -84,6 +99,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output = decode_func(input_ids, infer_state)
+        # We assume batch_size has already been mapped through find_closest_graph_batch_size by the caller.
         self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
         graph_obj.replay()
         return model_output
@@ -97,6 +113,9 @@ def _capture_decode_overlap(
         infer_state1: InferStateInfo,
     ):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
@@ -113,6 +132,7 @@ def _capture_decode_overlap(
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+        # We assume batch_size has already been mapped through find_closest_graph_batch_size by the caller.
         self.graph[batch_size] = (
             graph_obj,
             input_ids,
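The eviction policy above is a classic LRU built on `OrderedDict`: insertion order doubles as recency order, a `pop` plus re-insert refreshes an entry, and `popitem(last=False)` drops the stalest one. Below is a minimal, self-contained sketch of that pattern; `GraphPool`, `get`, and `put` are illustrative names, not part of this patch (the real code stores `(graph_obj, input_ids, infer_state, model_output)` tuples keyed by padded batch size):

```python
from collections import OrderedDict

class GraphPool:
    def __init__(self, max_size=16):
        self.max_size = max_size
        self.graphs = OrderedDict()  # front of the dict = least recently used

    def get(self, batch_size):
        if batch_size not in self.graphs:
            return None
        # Pop and re-insert to move the entry to the most recently used end.
        graph = self.graphs.pop(batch_size)
        self.graphs[batch_size] = graph
        return graph

    def put(self, batch_size, graph):
        if len(self.graphs) >= self.max_size:
            self.graphs.popitem(last=False)  # evict the oldest entry in O(1)
        self.graphs[batch_size] = graph

pool = GraphPool(max_size=2)
pool.put(1, "g1")
pool.put(2, "g2")
pool.get(1)        # touching 1 makes 2 the oldest entry
pool.put(4, "g4")  # evicts 2, not 1
assert list(pool.graphs) == [1, 4]
```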
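Requested batch sizes are padded up to the nearest captured size via `bisect`. The body of `find_closest_graph_batch_size` sits mostly outside the visible hunks, so the following is an assumed reconstruction, consistent only with its signature and the `bisect_left` call; it requires `cuda_graph_batch_sizes` to be sorted ascending:

```python
import bisect

cuda_graph_batch_sizes = [1, 2, 3, 4, 8, 16, 32]  # example capture sizes

def find_closest_graph_batch_size(batch_size):
    # bisect_left finds the first captured size that is >= batch_size.
    index = bisect.bisect_left(cuda_graph_batch_sizes, batch_size)
    if index < len(cuda_graph_batch_sizes):
        return cuda_graph_batch_sizes[index]
    return None  # larger than every captured size; caller falls back to eager mode

assert find_closest_graph_batch_size(5) == 8   # padded up to the next size
assert find_closest_graph_batch_size(33) is None
```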
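Capture itself follows the standard `torch.cuda.CUDAGraph` idiom: run the kernels once under `torch.cuda.graph(...)` against static buffers, then `replay()` them after refreshing the inputs in place. A minimal sketch with illustrative tensors rather than the real `decode_func` (the warm-up side stream is part of the general idiom, not shown in this diff):

```python
import torch

if torch.cuda.is_available():
    static_input = torch.zeros(8, 128, device="cuda")
    mempool = torch.cuda.graph_pool_handle()  # shared pool, as in self.mempool

    # Warm-up on a side stream so capture sees fully initialized kernels.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output = static_input * 2
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        static_output = static_input * 2  # recorded into the graph, not executed

    # Replay runs the captured kernels against the same static buffers,
    # so new input values must be copied/written into static_input first.
    static_input.fill_(3.0)
    graph.replay()
    assert static_output[0, 0].item() == 6.0
```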