Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import paddle.jit.dy2static.utils as jit_utils
import paddle.nn.layer
from paddle.base.core import CUDAGraph
from paddle.device.cuda import graphs

from fastdeploy import envs
Expand Down Expand Up @@ -92,7 +91,10 @@ def __init__(self, fd_config: FDConfig, runnable: Callable):
self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes
self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups
self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size
self.unique_memory_pool_id = None
if self.fd_config.graph_opt_config.use_unique_memory_pool:
from paddle.base.core import CUDAGraph

self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id()
self._create_entry_dict()

Expand Down Expand Up @@ -166,11 +168,7 @@ def __call__(self, **kwargs):
input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)]
entry.input_addresses = input_addresses

new_grpah = (
graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
if self.fd_config.graph_opt_config.use_unique_memory_pool
else graphs.CUDAGraph()
)
new_grpah = graphs.CUDAGraph(pool_id=self.unique_memory_pool_id)
paddle.device.synchronize()

# Capture
Expand Down
Loading