Skip to content

Commit e83718b

Browse files
authored
[AINode] Add a model ckpt path retrieve interface (#15689)
1 parent eeca531 commit e83718b

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

iotdb-core/ainode/ainode/core/manager/model_manager.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,18 @@ def load_model(self, model_id: str, acceleration: bool = False) -> Callable:
100100
logger.info(f"load model {model_id}")
101101
return self.model_storage.load_model(model_id, acceleration)
102102

103+
def get_ckpt_path(self, model_id: str) -> str:
104+
"""
105+
Get the checkpoint path for a given model ID.
106+
107+
Args:
108+
model_id (str): The ID of the model.
109+
110+
Returns:
111+
str: The path to the checkpoint file for the model.
112+
"""
113+
return self.model_storage.get_ckpt_path(model_id)
114+
103115
@staticmethod
104116
def load_built_in_model(model_id: str, attributes: {}):
105117
model_id = model_id.lower()

iotdb-core/ainode/ainode/core/model/model_storage.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,15 @@ def delete_model(self, model_id: str) -> None:
119119
def _remove_from_cache(self, file_path: str) -> None:
120120
if file_path in self._model_cache:
121121
del self._model_cache[file_path]
122+
123+
def get_ckpt_path(self, model_id: str) -> str:
124+
"""
125+
Get the checkpoint path for a given model ID.
126+
127+
Args:
128+
model_id (str): The ID of the model.
129+
130+
Returns:
131+
str: The path to the checkpoint file for the model.
132+
"""
133+
return os.path.join(self._model_dir, f"{model_id}")

0 commit comments

Comments
 (0)