Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions iotdb-core/ainode/ainode/core/manager/inference_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def infer(self, full_data, predict_length=96, **_):


class BuiltInStrategy(InferenceStrategy):
def infer(self, full_data, **_):
def infer(self, full_data):
data = pd.DataFrame(full_data[1]).T
output = self.model.inference(data)
df = pd.DataFrame(output)
Expand Down Expand Up @@ -121,20 +121,19 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs):
return [convert_to_binary(df) for df in results]


def _get_strategy(model_id, model):
    """Select the inference strategy that matches the loaded model.

    Dedicated strategies are used for the Timer-XL and Sundial foundation
    models; for everything else the model_id prefix decides between a
    built-in model and a user-registered one.
    """
    # Dispatch table for the large-model special cases, checked in order.
    dedicated = (
        (TimerForPrediction, TimerXLStrategy),
        (SundialForPrediction, SundialStrategy),
    )
    for model_cls, strategy_cls in dedicated:
        if isinstance(model, model_cls):
            return strategy_cls(model)
    # Built-in model ids are conventionally prefixed with "_".
    is_built_in = model_id.startswith("_")
    return BuiltInStrategy(model) if is_built_in else RegisteredStrategy(model)


class InferenceManager:
def __init__(self, model_manager: ModelManager):
    """Create an InferenceManager backed by the given ModelManager.

    The manager is kept as-is; it is used later to load models and to
    decide (via its model storage) whether a model id is built-in.
    """
    self.model_manager = model_manager

def _get_strategy(self, model_id, model):
    """Pick the inference strategy for an already-loaded model.

    Timer-XL and Sundial instances get their dedicated strategies; any
    other model is classified by the model storage as either built-in
    or user-registered.
    """
    # Check the foundation-model special cases first, in declaration order.
    for model_cls, strategy_cls in (
        (TimerForPrediction, TimerXLStrategy),
        (SundialForPrediction, SundialStrategy),
    ):
        if isinstance(model, model_cls):
            return strategy_cls(model)
    # Fall back to the storage layer's classification of the id.
    storage = self.model_manager.model_storage
    if storage._is_built_in(model_id):
        return BuiltInStrategy(model)
    return RegisteredStrategy(model)

def _run(
self,
req,
Expand All @@ -156,11 +155,11 @@ def _run(

# load model
accel = str(inference_attrs.get("acceleration", "")).lower() == "true"
model = self.model_manager.load_model(model_id, accel)
model = self.model_manager.load_model(model_id, inference_attrs, accel)

# inference by strategy
strategy = _get_strategy(model_id, model)
outputs = strategy.infer(full_data, **inference_attrs)
strategy = self._get_strategy(model_id, model)
outputs = strategy.infer(full_data)

# construct response
status = get_status(TSStatusCode.SUCCESS_STATUS)
Expand Down
10 changes: 7 additions & 3 deletions iotdb-core/ainode/ainode/core/manager/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
from typing import Callable
from typing import Callable, Dict

from torch import nn
from yaml import YAMLError
Expand Down Expand Up @@ -97,13 +97,17 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus:
logger.warning(e)
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))

def load_model(self, model_id: str, acceleration: bool = False) -> Callable:
def load_model(
self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False
) -> Callable:
"""
Load the model with the given model_id.
"""
logger.info(f"Load model {model_id}")
try:
model = self.model_storage.load_model(model_id, acceleration)
model = self.model_storage.load_model(
model_id, inference_attrs, acceleration
)
logger.info(f"Model {model_id} loaded")
return model
except Exception as e:
Expand Down
8 changes: 6 additions & 2 deletions iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def get_model_attributes(model_type: BuiltInModelType):
return attribute_map


def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
def fetch_built_in_model(
model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
) -> Callable:
"""
Fetch the built-in model according to its id and directory; note that this directory only contains model weights and config.
Args:
Expand All @@ -132,7 +134,9 @@ def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
Returns:
model: the built-in model
"""
attributes = get_model_attributes(model_type)
default_attributes = get_model_attributes(model_type)
# parse the attributes from inference_attrs
attributes = parse_attribute(inference_attrs, default_attributes)

# build the built-in model
if model_type == BuiltInModelType.ARIMA:
Expand Down
5 changes: 4 additions & 1 deletion iotdb-core/ainode/ainode/core/model/model_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ def _is_built_in(self, model_id: str) -> bool:
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
)

def load_model(self, model_id: str, acceleration: bool) -> Callable:
def load_model(
self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool
) -> Callable:
"""
Load a model with automatic detection of .safetensors or .pt format

Expand All @@ -275,6 +277,7 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable:
return fetch_built_in_model(
get_built_in_model_type(self._model_info_map[model_id].model_type),
model_dir,
inference_attrs,
)
else:
# load the user-defined model
Expand Down
Loading