
Commit e63cae7

recover need_capture
1 parent 112a28e

2 files changed: +5 −9 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 2 additions & 6 deletions

@@ -381,9 +381,7 @@ def _decode(
         )
         infer_state.init_some_extra_state(self, padded_model_input.input_ids)

-        # Check if a graph needs to be captured.
-        # get_graph returns None if a graph for the batch_size doesn't exist.
-        if self.graph.get_graph(find_graph_batch_size) is None:
+        if self.graph.need_capture(find_graph_batch_size):
             infer_state.is_cuda_graph = True
             model_output: ModelOutput = self.graph.capture_decode(
                 self._token_forward, padded_model_input.input_ids, infer_state
@@ -574,9 +572,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
         )
         infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

-        # Check if a graph needs to be captured.
-        # get_graph returns None if a graph for the batch_size doesn't exist.
-        if self.graph.get_graph(find_graph_batch_size) is None:
+        if self.graph.need_capture(find_graph_batch_size):
             infer_state0.is_cuda_graph = True
             infer_state1.is_cuda_graph = True

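Both hunks restore the same call-site shape: ask the graph cache whether a CUDA graph for this batch size still needs to be captured, and only flag `is_cuda_graph` and capture when it does. A minimal sketch of that capture-or-replay decision (the `GraphCache` class and `decode_step` helper are illustrative stand-ins, not the real lightllm API):

    # Hypothetical sketch of the capture-or-replay decision;
    # GraphCache and decode_step are stand-ins, not lightllm code.
    class GraphCache:
        def __init__(self):
            self.graph = {}  # batch_size -> captured graph

        def need_capture(self, batch_size):
            # True iff no graph has been captured yet for this batch size.
            return batch_size not in self.graph

    def decode_step(cache, batch_size):
        if cache.need_capture(batch_size):
            cache.graph[batch_size] = f"graph<{batch_size}>"  # capture once
            return "captured"
        return "replayed"  # every later call for this size replays

    cache = GraphCache()
    assert decode_step(cache, 8) == "captured"  # first call captures
    assert decode_step(cache, 8) == "replayed"  # subsequent calls replay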
lightllm/common/basemodel/cuda_graph.py

Lines changed: 3 additions & 3 deletions

@@ -49,15 +49,15 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

-    def get_graph(self, batch_size):
+    def need_capture(self, batch_size):
         # We assume batch_size has already been adjusted to the closest supported graph batch size
         # If the graph already exists, get it and move it to the most recently used position.
         if batch_size in self.graph:
             find_graph = self.graph.pop(batch_size)  # Dequeue the graph
             self.graph[batch_size] = find_graph  # Enqueue the graph for LRU
-            return find_graph
+            return False
         else:
-            return None
+            return True

     def evict_oldest_graph(self):
         if self.graph:
