diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py index 95fdda1456b14..ced7277c1ae2e 100644 --- a/iotdb-core/ainode/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py @@ -100,6 +100,18 @@ def load_model(self, model_id: str, acceleration: bool = False) -> Callable: logger.info(f"load model {model_id}") return self.model_storage.load_model(model_id, acceleration) + def get_ckpt_path(self, model_id: str) -> str: + """ + Get the checkpoint path for a given model ID. + + Args: + model_id (str): The ID of the model. + + Returns: + str: The path to the checkpoint file for the model. + """ + return self.model_storage.get_ckpt_path(model_id) + @staticmethod def load_built_in_model(model_id: str, attributes: {}): model_id = model_id.lower() diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py index c0e2a21c80a8d..864b5c30e0abf 100644 --- a/iotdb-core/ainode/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/ainode/core/model/model_storage.py @@ -119,3 +119,15 @@ def delete_model(self, model_id: str) -> None: def _remove_from_cache(self, file_path: str) -> None: if file_path in self._model_cache: del self._model_cache[file_path] + + def get_ckpt_path(self, model_id: str) -> str: + """ + Get the checkpoint path for a given model ID. + + Args: + model_id (str): The ID of the model. + + Returns: + str: The path to the checkpoint file for the model. + """ + return os.path.join(self._model_dir, f"{model_id}")