
Commit 98a09a0

shihaobai authored and root committed
cuda graph pool with LRU
1 parent 81c5f61 commit 98a09a0


4 files changed, +41 -9 lines changed


lightllm/common/basemodel/basemodel.py

Lines changed: 6 additions & 2 deletions
@@ -346,6 +346,8 @@ def _decode(
     ) -> ModelOutput:
         if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
+            assert find_graph_batch_size is not None
+
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)
             infer_state = self._create_inferstate(padded_model_input)
             copy_kv_index_to_req(
@@ -356,7 +358,7 @@ def _decode(
             )
             infer_state.init_some_extra_state(self, padded_model_input.input_ids)

-            if self.graph.need_capture(find_graph_batch_size):
+            if self.graph.get_graph(find_graph_batch_size) is None:
                 infer_state.is_cuda_graph = True
                 model_output: ModelOutput = self.graph.capture_decode(
                     self._token_forward, padded_model_input.input_ids, infer_state
@@ -497,6 +499,8 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode

         if self.graph is not None and self.graph.can_run(origin_batch_size, max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(origin_batch_size)
+            assert find_graph_batch_size is not None
+
             padded_model_input0 = self._create_padded_decode_model_input(model_input0, find_graph_batch_size)
             padded_model_input1 = self._create_padded_decode_model_input(model_input1, find_graph_batch_size)
             infer_state0 = self._create_inferstate(padded_model_input0, 0)
@@ -516,7 +520,7 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
             )
             infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)

-            if self.graph.need_capture(find_graph_batch_size):
+            if self.graph.get_graph(find_graph_batch_size) is None:
                 infer_state0.is_cuda_graph = True
                 infer_state1.is_cuda_graph = True

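Note on the change above: the decode paths now ask the graph pool directly with get_graph() instead of the removed need_capture() helper, and a None result means a graph for the padded batch size has to be captured (or re-captured, if it was evicted). A minimal, self-contained sketch of that dispatch; _StubPool and decode_dispatch are illustrative placeholders, not lightllm APIs:

class _StubPool:
    """Toy stand-in for CudaGraph, only to exercise the dispatch logic above."""

    def __init__(self):
        self.cached = {}

    def find_closest_graph_batch_size(self, batch_size):
        # The real method bisects self.cuda_graph_batch_sizes; one bucket is enough here.
        return 8 if batch_size <= 8 else None

    def get_graph(self, batch_size):
        return self.cached.get(batch_size)


def decode_dispatch(pool, batch_size):
    graph_batch_size = pool.find_closest_graph_batch_size(batch_size)
    assert graph_batch_size is not None
    if pool.get_graph(graph_batch_size) is None:
        pool.cached[graph_batch_size] = "captured-graph"  # stands in for capture_decode()
        return "capture + replay"
    return "replay only"


pool = _StubPool()
print(decode_dispatch(pool, 3))  # capture + replay  (first request at this padded size)
print(decode_dispatch(pool, 5))  # replay only       (graph for padded size 8 is cached)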

lightllm/common/basemodel/cuda_graph.py

Lines changed: 27 additions & 7 deletions
@@ -2,27 +2,29 @@
 import torch
 import copy
 import bisect
+from collections import OrderedDict
 from typing import Optional
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from .infer_struct import InferStateInfo

-
 logger = init_logger(__name__)


 class CudaGraph:
     # CudaGraph forward pass for the decoding stage.

     def __init__(self, max_batch_size=8, max_len_in_batch=8192):
-        self.graph = {}
+        self.graph = OrderedDict()  # for LRU
+
         self.mempool = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
         self.max_batch_size = max_batch_size
         self.graph_max_len_in_batch = max_len_in_batch
         self.args = get_env_start_args()
         self.enable_decode_microbatch_overlap = self.args.enable_decode_microbatch_overlap
+        self.max_graph_pool_size = self.args.max_graph_pool_size

         # gen cuda graph batch_sizes
         # cuda graph gen for batch size = [1, 2, 3, ..., graph_split_batch_size]
@@ -47,12 +49,22 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192):
     def can_run(self, batch_size, max_len_in_batch):
         return batch_size <= self.max_batch_size and max_len_in_batch <= self.graph_max_len_in_batch

-    def need_capture(self, batch_size):
-        find_batch_size = self.find_closest_graph_batch_size(batch_size)
-        if find_batch_size is not None:
-            return find_batch_size not in self.graph
+    def get_graph(self, batch_size):
+        # we assume batch_size is already dealed with find_closest_graph_batch_size outside
+        # If the graph already exists, dequeue it and then enqueue it,
+        # thus move it to the most recently used position.
+        if batch_size in self.graph:
+            find_graph = self.graph.pop(batch_size)
+            self.graph[batch_size] = find_graph
+            return find_graph
         else:
-            assert False, "dead code"
+            return None
+
+    def evict_oldest_graph(self):
+        if self.graph:
+            oldest_batch_size, oldest_graph = self.graph.popitem(last=False)
+            del oldest_graph
+            torch.cuda.empty_cache()

     def find_closest_graph_batch_size(self, batch_size):
         index = bisect.bisect_left(self.cuda_graph_batch_sizes, batch_size)
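A quick, self-contained illustration of the LRU behaviour that get_graph() and evict_oldest_graph() rely on: popping a key and re-inserting it moves it to the most-recently-used end of the OrderedDict, and popitem(last=False) removes the least-recently-used entry. The string values below stand in for the (graph_obj, input_ids, infer_state, model_output) tuples the real pool stores:

from collections import OrderedDict

pool = OrderedDict()
pool[8] = "graph-8"
pool[16] = "graph-16"

# Touch batch size 8: dequeue and re-enqueue, moving it to the MRU position.
entry = pool.pop(8)
pool[8] = entry
print(list(pool))        # [16, 8] -> 16 is now the oldest entry

# Evict the least recently used graph (what evict_oldest_graph() does, minus empty_cache()).
pool.popitem(last=False)
print(list(pool))        # [8]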
@@ -64,6 +76,9 @@ def find_closest_graph_batch_size(self, batch_size):

     def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: InferStateInfo):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
         infer_state.max_len_in_batch = self.graph_max_len_in_batch
@@ -84,6 +99,7 @@ def _capture_decode(self, decode_func, input_ids: torch.Tensor, infer_state: Inf
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output = decode_func(input_ids, infer_state)
+        # we assume batch_size is already dealed with find_closest_graph_batch_size outside
         self.graph[batch_size] = (graph_obj, input_ids, infer_state, model_output)
         graph_obj.replay()
         return model_output
@@ -97,6 +113,9 @@ def _capture_decode_overlap(
         infer_state1: InferStateInfo,
     ):
         dist_group: CustomProcessGroup = infer_state.dist_group
+        if len(self.graph) >= self.max_graph_pool_size:
+            self.evict_oldest_graph()
+
         dist_group1 = infer_state1.dist_group
         graph_obj = torch.cuda.CUDAGraph()
         batch_size = input_ids.shape[0]
@@ -113,6 +132,7 @@
         with lightllm_capture_graph(dist_group):
             with torch.cuda.graph(graph_obj, pool=self.mempool):
                 model_output, model_output1 = decode_func(input_ids, infer_state, input_ids1, infer_state1)
+        # we assume batch_size is already dealed with find_closest_graph_batch_size outside
         self.graph[batch_size] = (
             graph_obj,
             input_ids,
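Condensed view of how eviction interacts with capture in _capture_decode above. This is a sketch only: the real method also pads inputs, sets up infer_state and the distributed capture context, and shares memory through self.mempool; it needs a CUDA device and a real decode function to run.

import torch
from collections import OrderedDict

def capture_into_pool(pool, max_pool_size, batch_size, decode_func, input_ids, infer_state):
    if len(pool) >= max_pool_size:
        pool.popitem(last=False)        # drop the least recently used graph first
        torch.cuda.empty_cache()        # give the freed memory back to the allocator

    graph_obj = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph_obj):   # the real code also passes pool=self.mempool
        model_output = decode_func(input_ids, infer_state)
    pool[batch_size] = (graph_obj, input_ids, infer_state, model_output)
    graph_obj.replay()
    return model_output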

lightllm/server/api_cli.py

Lines changed: 7 additions & 0 deletions
@@ -335,6 +335,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
     parser.add_argument("--disable_cudagraph", action="store_true", help="Disable the cudagraph of the decoding stage")

+    parser.add_argument(
+        "--max_graph_pool_size",
+        type=int,
+        default=16,
+        help="""Maximum cuda graph pool size for decoding stage.""",
+    )
+
     parser.add_argument(
         "--graph_max_batch_size",
         type=int,
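The new flag sits next to the existing cudagraph options and is supplied at server start, like --graph_max_batch_size. A quick way to confirm the parser wiring, assuming lightllm is importable in the current environment:

from lightllm.server.api_cli import make_argument_parser

parser = make_argument_parser()
print(parser.get_default("max_graph_pool_size"))   # 16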

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ class StartArgs:
     visual_nccl_ports: List[int] = field(default_factory=lambda: [29500])
     enable_monitor_auth: bool = field(default=False)
     disable_cudagraph: bool = field(default=False)
+    max_graph_pool_size: int = field(default=16)
     graph_max_batch_size: int = field(default=256)
     graph_split_batch_size: int = field(default=32)
     graph_grow_step_size: int = field(default=16)
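The StartArgs default mirrors the CLI default, so code that reads the start args through get_env_start_args() (as CudaGraph.__init__ does above) sees max_graph_pool_size=16 even when the flag is not given. Inspecting the field default without constructing the full dataclass (illustration only):

from dataclasses import fields
from lightllm.server.core.objs.start_args_type import StartArgs

defaults = {f.name: f.default for f in fields(StartArgs)}
print(defaults["max_graph_pool_size"])   # 16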
