@@ -20,6 +20,7 @@
 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
+from paddle.base.core import CUDAGraph
 from paddle.device.cuda import graphs

 from fastdeploy import envs
@@ -85,17 +86,14 @@ def run_impl_guard(self):
 class CudaGraphPiecewiseBackend:
     """Manage the capture and replay of CUDA graphs at the subgraph level."""

-    def __init__(
-        self,
-        fd_config: FDConfig,
-        runnable: Callable,
-    ):
+    def __init__(self, fd_config: FDConfig, runnable: Callable):
         self.fd_config = fd_config
         self.runnable = runnable
         self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
         self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
-
+        if self.fd_config.graph_opt_config.use_unique_memory_pool:
+            self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
         self._create_entry_dict()

         self.cuda_graph_manager = None
@@ -168,7 +166,11 @@ def __call__(self, **kwargs):
             input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
             entry.input_addresses = input_addresses

-            new_grpah = graphs.CUDAGraph()
+            new_grpah = (
+                graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
+                if self.fd_config.graph_opt_config.use_unique_memory_pool
+                else graphs.CUDAGraph()
+            )
             paddle.device.synchronize()

             # Capture