
Commit bab7790

[CudaGraph] support cudagraph use shared pool (#4199)
* support cudagraph use shared pool
* add envs
* change CUDAGRAPH_POOL_ID to int
* change CUDAGRAPH_POOL_ID to use_memory_pool
* unify use_unique_memory_pool
* fix use_unique_memory_pool
1 parent e2b68b3 commit bab7790
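
In short: when the new `use_unique_memory_pool` flag is enabled, the piecewise backend generates a single CUDA graph memory pool id once at construction and passes it to every `graphs.CUDAGraph` it captures, instead of letting each captured graph allocate a private pool. The usual motivation for this (not stated in the commit itself) is memory savings: graphs for different capture sizes never replay concurrently, so one shared pool lets them reuse the same device memory.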

2 files changed (+12, -7)

fastdeploy/config.py (3 additions, 0 deletions)

```diff
@@ -588,6 +588,9 @@ def __init__(
         Thus this flag cannot be used together with splitting_ops."""
         self.full_cuda_graph: bool = True
 
+        """ Whether to use shared memory pool for multi capture_size """
+        self.use_unique_memory_pool: bool = False
+
         self.max_capture_size: int = None
         self.real_shape_to_captured_size: dict[int, int] = None
         # CINN Config ...
```
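
The flag defaults to `False`, so shared-pool capture is opt-in. A minimal, self-contained sketch of the semantics (a stand-in dataclass for illustration, not FastDeploy's actual config class):

```python
# Stand-in sketch, not FastDeploy code: mirrors the attribute and default
# added to fastdeploy/config.py above.
from dataclasses import dataclass


@dataclass
class GraphOptConfigSketch:
    # False: each captured graph keeps its own private memory pool.
    # True:  every capture size shares one pool (this commit's feature).
    use_unique_memory_pool: bool = False


cfg = GraphOptConfigSketch(use_unique_memory_pool=True)
assert cfg.use_unique_memory_pool
```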

fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py (9 additions, 7 deletions)

```diff
@@ -20,6 +20,7 @@
 
 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
+from paddle.base.core import CUDAGraph
 from paddle.device.cuda import graphs
 
 from fastdeploy import envs
@@ -85,17 +86,14 @@ def run_impl_guard(self):
 class CudaGraphPiecewiseBackend:
     """Manage the capture and replay of CUDA graphs at the subgraph level."""
 
-    def __init__(
-        self,
-        fd_config: FDConfig,
-        runnable: Callable,
-    ):
+    def __init__(self, fd_config: FDConfig, runnable: Callable):
         self.fd_config = fd_config
         self.runnable = runnable
         self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
         self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
-
+        if self.fd_config.graph_opt_config.use_unique_memory_pool:
+            self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
         self._create_entry_dict()
 
         self.cuda_graph_manager = None
@@ -168,7 +166,11 @@ def __call__(self, **kwargs):
         input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
         entry.input_addresses = input_addresses
 
-        new_grpah = graphs.CUDAGraph()
+        new_grpah = (
+            graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
+            if self.fd_config.graph_opt_config.use_unique_memory_pool
+            else graphs.CUDAGraph()
+        )
         paddle.device.synchronize()
 
         # Capture
```
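
To make the capture-side change concrete, here is a hedged sketch of capturing one graph per capture size into a single shared pool, using the two Paddle entry points the diff relies on (`CUDAGraph.gen_new_memory_pool_id()` and `graphs.CUDAGraph(pool_id=...)`). It assumes a CUDA build of Paddle running on a GPU; `dummy_step` and the capture sizes are illustrative, not part of this commit:

```python
# Hedged sketch: capture one CUDA graph per capture size, all drawing
# from a single shared memory pool, as this commit does when
# use_unique_memory_pool is enabled. Requires a CUDA build of Paddle.
import paddle
from paddle.base.core import CUDAGraph  # pool id generator (from the diff)
from paddle.device.cuda import graphs   # graph capture API (from the diff)


def dummy_step(x):
    # Stand-in for the runnable the backend actually replays.
    return x * 2.0 + 1.0


pool_id = CUDAGraph.gen_new_memory_pool_id()  # one pool for every capture

captured = {}
for capture_size in (1, 2, 4):  # illustrative capture sizes
    x = paddle.zeros([capture_size, 16])
    g = graphs.CUDAGraph(pool_id=pool_id)  # shared pool instead of a private one
    g.capture_begin()
    y = dummy_step(x)
    g.capture_end()
    captured[capture_size] = (g, x, y)

# Replay: copy real inputs into x, call g.replay(); outputs land in y.
```

Note the constraint this design leans on: buffers in the shared pool are reused across graphs, so it is only safe because the graphs for different capture sizes never replay concurrently; overlapping replays would race on the same memory.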
