Skip to content

Commit b04cd46

Browse files
authored
[AINode] Fix built-in model inference & support user parameters (#15868)
1 parent 5ef583c commit b04cd46

File tree

4 files changed

+30
-20
lines changed

4 files changed

+30
-20
lines changed

iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def infer(self, full_data, predict_length=96, **_):
8484

8585

8686
class BuiltInStrategy(InferenceStrategy):
87-
def infer(self, full_data, **_):
87+
def infer(self, full_data):
8888
data = pd.DataFrame(full_data[1]).T
8989
output = self.model.inference(data)
9090
df = pd.DataFrame(output)
@@ -121,20 +121,19 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs):
121121
return [convert_to_binary(df) for df in results]
122122

123123

124-
def _get_strategy(model_id, model):
125-
if isinstance(model, TimerForPrediction):
126-
return TimerXLStrategy(model)
127-
if isinstance(model, SundialForPrediction):
128-
return SundialStrategy(model)
129-
if model_id.startswith("_"):
130-
return BuiltInStrategy(model)
131-
return RegisteredStrategy(model)
132-
133-
134124
class InferenceManager:
135125
def __init__(self, model_manager: ModelManager):
136126
self.model_manager = model_manager
137127

128+
def _get_strategy(self, model_id, model):
129+
if isinstance(model, TimerForPrediction):
130+
return TimerXLStrategy(model)
131+
if isinstance(model, SundialForPrediction):
132+
return SundialStrategy(model)
133+
if self.model_manager.model_storage._is_built_in(model_id):
134+
return BuiltInStrategy(model)
135+
return RegisteredStrategy(model)
136+
138137
def _run(
139138
self,
140139
req,
@@ -156,11 +155,11 @@ def _run(
156155

157156
# load model
158157
accel = str(inference_attrs.get("acceleration", "")).lower() == "true"
159-
model = self.model_manager.load_model(model_id, accel)
158+
model = self.model_manager.load_model(model_id, inference_attrs, accel)
160159

161160
# inference by strategy
162-
strategy = _get_strategy(model_id, model)
163-
outputs = strategy.infer(full_data, **inference_attrs)
161+
strategy = self._get_strategy(model_id, model)
162+
outputs = strategy.infer(full_data)
164163

165164
# construct response
166165
status = get_status(TSStatusCode.SUCCESS_STATUS)

iotdb-core/ainode/ainode/core/manager/model_manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18-
from typing import Callable
18+
from typing import Callable, Dict
1919

2020
from torch import nn
2121
from yaml import YAMLError
@@ -97,13 +97,17 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus:
9797
logger.warning(e)
9898
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
9999

100-
def load_model(self, model_id: str, acceleration: bool = False) -> Callable:
100+
def load_model(
101+
self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False
102+
) -> Callable:
101103
"""
102104
Load the model with the given model_id.
103105
"""
104106
logger.info(f"Load model {model_id}")
105107
try:
106-
model = self.model_storage.load_model(model_id, acceleration)
108+
model = self.model_storage.load_model(
109+
model_id, inference_attrs, acceleration
110+
)
107111
logger.info(f"Model {model_id} loaded")
108112
return model
109113
except Exception as e:

iotdb-core/ainode/ainode/core/model/built_in_model_factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ def get_model_attributes(model_type: BuiltInModelType):
123123
return attribute_map
124124

125125

126-
def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
126+
def fetch_built_in_model(
127+
model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
128+
) -> Callable:
127129
"""
128130
Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config.
129131
Args:
@@ -132,7 +134,9 @@ def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
132134
Returns:
133135
model: the built-in model
134136
"""
135-
attributes = get_model_attributes(model_type)
137+
default_attributes = get_model_attributes(model_type)
138+
# parse the attributes from inference_attrs
139+
attributes = parse_attribute(inference_attrs, default_attributes)
136140

137141
# build the built-in model
138142
if model_type == BuiltInModelType.ARIMA:

iotdb-core/ainode/ainode/core/model/model_storage.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,9 @@ def _is_built_in(self, model_id: str) -> bool:
262262
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
263263
)
264264

265-
def load_model(self, model_id: str, acceleration: bool) -> Callable:
265+
def load_model(
266+
self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool
267+
) -> Callable:
266268
"""
267269
Load a model with automatic detection of .safetensors or .pt format
268270
@@ -275,6 +277,7 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable:
275277
return fetch_built_in_model(
276278
get_built_in_model_type(self._model_info_map[model_id].model_type),
277279
model_dir,
280+
inference_attrs,
278281
)
279282
else:
280283
# load the user-defined model

0 commit comments

Comments
 (0)