diff --git a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py index 79ff9ea0e1..3465b60928 100644 --- a/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py +++ b/fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py @@ -20,7 +20,6 @@ import paddle.jit.dy2static.utils as jit_utils import paddle.nn.layer -from paddle.base.core import CUDAGraph from paddle.device.cuda import graphs from fastdeploy import envs @@ -92,8 +91,12 @@ def __init__(self, fd_config: FDConfig, runnable: Callable): self.cudagraph_capture_sizes = fd_config.graph_opt_config.cudagraph_capture_sizes self.warm_up_size = fd_config.graph_opt_config.cudagraph_num_of_warmups self.real_shape_to_captured_size = fd_config.graph_opt_config.real_shape_to_captured_size + self.unique_memory_pool_id = None if self.fd_config.graph_opt_config.use_unique_memory_pool: - self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id() + if paddle.is_compiled_with_cuda(): + from paddle.base.core import CUDAGraph + + self.unique_memory_pool_id = CUDAGraph.gen_new_memory_pool_id() self._create_entry_dict() self.cuda_graph_manager = None @@ -166,11 +169,7 @@ def __call__(self, **kwargs): input_addresses = [x.data_ptr() for (_, x) in kwargs.items() if isinstance(x, paddle.Tensor)] entry.input_addresses = input_addresses - new_grpah = ( - graphs.CUDAGraph(pool_id=self.unique_memory_pool_id) - if self.fd_config.graph_opt_config.use_unique_memory_pool - else graphs.CUDAGraph() - ) + new_grpah = graphs.CUDAGraph(pool_id=self.unique_memory_pool_id) paddle.device.synchronize() # Capture