|
35 | 35 | from ....utils.benchmark import ENTRY_POINT_NAME, benchmark |
36 | 36 | from ....utils.hpi import HPIConfig, HPIInfo |
37 | 37 | from ....utils.io import YAMLReader |
| 38 | +from ....utils.model_paths import get_model_paths |
38 | 39 | from ....utils.pp_option import PaddlePredictorOption |
39 | 40 | from ...common import HPInfer, PaddleInfer |
40 | 41 | from ...common.genai import GenAIClient, GenAIConfig, need_local_model |
@@ -157,13 +158,20 @@ def __init__( |
157 | 158 |
|
158 | 159 | if self._use_local_model: |
159 | 160 | self._use_hpip = use_hpip |
160 | | - if not use_hpip: |
161 | | - self._pp_option = self._prepare_pp_option(pp_option, device) |
| 161 | + model_paths = get_model_paths(self.model_dir) |
| 162 | + if "paddle_dyn" in model_paths or "safetensors" in model_paths: |
| 163 | + self._use_static_model = False |
162 | 164 | else: |
163 | | - require_hpip() |
164 | | - self._hpi_config = self._prepare_hpi_config(hpi_config, device) |
| 165 | + self._use_static_model = True |
| 166 | + if self._use_static_model: |
| 167 | + if not use_hpip: |
| 168 | + self._pp_option = self._prepare_pp_option(pp_option, device) |
| 169 | + else: |
| 170 | + require_hpip() |
| 171 | + self._hpi_config = self._prepare_hpi_config(hpi_config, device) |
165 | 172 | else: |
166 | 173 | self._use_hpip = False |
| 174 | + self._use_static_model = False |
167 | 175 |
|
168 | 176 | logging.debug(f"{self.__class__.__name__}: {self.model_dir}") |
169 | 177 |
|
|
0 commit comments