Skip to content

Commit 870364b

Browse files
authored
[CUDAGraph]CUDA Graph support unique memory pool (#4230)
* cuda graph use unique memory pool * fix custom device import bug * refine code * refine code * refine code
1 parent 5ff10c8 commit 870364b

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

fastdeploy/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,8 +841,13 @@ def __init__(
841841
Now don't support capture both decode-only and prefill-only"""
842842
self.full_cuda_graph: bool = True
843843

844+
""" Maximum CUDA Graph capture size """
844845
self.max_capture_size: int = None
846+
""" Record maps mapped from real shape to captured size to reduce runtime overhead """
845847
self.real_shape_to_captured_size: dict[int, int] = None
848+
""" Whether to use shared memory pool for multi capture_size """
849+
self.use_unique_memory_pool: bool = False
850+
846851
# CINN Config ...
847852
if args is not None:
848853
for key, value in args.items():

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,13 @@ def __init__(
9696
self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
9797
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
9898
self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
99+
self.unique_memory_pool_id = None
100+
if self.fd_config.graph_opt_config.use_unique_memory_pool:
101+
# TODO(gongshaotian): Optimize code
102+
if paddle.is_compiled_with_cuda():
103+
from paddle.base.core import CUDAGraph
104+
105+
self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
99106

100107
self._create_entry_dict()
101108

@@ -169,7 +176,7 @@ def __call__(self, **kwargs) -> List[paddle.Tensor] | paddle.Tensor:
169176
input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
170177
entry.input_addresses = input_addresses
171178

172-
new_grpah = graphs.CUDAGraph()
179+
new_grpah = graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
173180
paddle.device.synchronize()
174181

175182
# Capture

0 commit comments

Comments
 (0)