Skip to content

Commit 9c2a4a9

Browse files
authored
Refactor to support generic dynamic graph models (#4718)
1 parent 33d0205 commit 9c2a4a9

32 files changed

+64
-30
lines changed

paddlex/inference/models/base/predictor/base_predictor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from ....utils.benchmark import ENTRY_POINT_NAME, benchmark
3636
from ....utils.hpi import HPIConfig, HPIInfo
3737
from ....utils.io import YAMLReader
38+
from ....utils.model_paths import get_model_paths
3839
from ....utils.pp_option import PaddlePredictorOption
3940
from ...common import HPInfer, PaddleInfer
4041
from ...common.genai import GenAIClient, GenAIConfig, need_local_model
@@ -157,13 +158,20 @@ def __init__(
157158

158159
if self._use_local_model:
159160
self._use_hpip = use_hpip
160-
if not use_hpip:
161-
self._pp_option = self._prepare_pp_option(pp_option, device)
161+
model_paths = get_model_paths(self.model_dir)
162+
if "paddle_dyn" in model_paths or "safetensors" in model_paths:
163+
self._use_static_model = False
162164
else:
163-
require_hpip()
164-
self._hpi_config = self._prepare_hpi_config(hpi_config, device)
165+
self._use_static_model = True
166+
if self._use_static_model:
167+
if not use_hpip:
168+
self._pp_option = self._prepare_pp_option(pp_option, device)
169+
else:
170+
require_hpip()
171+
self._hpi_config = self._prepare_hpi_config(hpi_config, device)
165172
else:
166173
self._use_hpip = False
174+
self._use_static_model = False
167175

168176
logging.debug(f"{self.__class__.__name__}: {self.model_dir}")
169177

0 commit comments

Comments
 (0)