
Commit f1ea1d9

Simplify CUDAGraph creation logic
Refactor CUDAGraph initialization to always pass the memory pool ID, which is set only when a unique memory pool is configured.
1 parent b176cba · commit f1ea1d9

File tree

1 file changed: 4 additions, 6 deletions


fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 4 additions & 6 deletions
@@ -20,7 +20,6 @@
 
 import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
-from paddle.base.core import CUDAGraph
 from paddle.device.cuda import graphs
 
 from fastdeploy import envs
@@ -92,7 +91,10 @@ def __init__(self, fd_config: FDConfig, runnable: Callable):
         self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
         self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
         self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
+        self.unique_memory_pool_id = None
         if self.fd_config.graph_opt_config.use_unique_memory_pool:
+            from paddle.base.core import CUDAGraph
+
             self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
         self._create_entry_dict()
 
@@ -166,11 +168,7 @@ def __call__(self, **kwargs):
             input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
             entry.input_addresses = input_addresses
 
-            new_grpah = (
-                graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
-                if self.fd_config.graph_opt_config.use_unique_memory_pool
-                else graphs.CUDAGraph()
-            )
+            new_grpah = graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
             paddle.device.synchronize()
 
             # Capture
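
After this change, the capture path no longer branches on use_unique_memory_pool: the pool ID is resolved once in __init__ (it stays None when no unique pool is configured) and is always forwarded to graphs.CUDAGraph. Below is a minimal, self-contained sketch of that pattern, not the actual FastDeploy backend class; it assumes, as the commit does, that graphs.CUDAGraph(pool_id=None) behaves the same as constructing the graph without a pool argument, and the class and method names in the sketch are illustrative.

# Minimal sketch (illustrative names; not the FastDeploy backend itself).
import paddle
from paddle.device.cuda import graphs


class GraphCaptureSketch:
    def __init__(self, use_unique_memory_pool: bool):
        # Resolve the pool ID once: None means the default CUDA graph pool.
        self.unique_memory_pool_id = None
        if use_unique_memory_pool:
            from paddle.base.core import CUDAGraph

            self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()

    def capture(self, runnable, *args):
        # Always pass pool_id; no conditional construction is needed, assuming
        # pool_id=None is equivalent to omitting the argument.
        new_graph = graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
        paddle.device.synchronize()
        new_graph.capture_begin()
        output = runnable(*args)
        new_graph.capture_end()
        return new_graph, output

Replaying is then a matter of calling new_graph.replay() with the captured input buffers left in place, which is why the backend records input_addresses for each entry.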
