Skip to content

Commit 8d00ce7

Browse files
RkGrit, CRZbulabula
authored and committed
support various pipeline Interfaces and support arima with sktime package (#16861)
1 parent a8bfbcc commit 8d00ce7

File tree

12 files changed

+111
-50
lines changed

12 files changed

+111
-50
lines changed

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE
3131
from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher
3232
from iotdb.ainode.core.inference.inference_request import InferenceRequest
33+
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline
3334
from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
3435
from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import (
3536
BasicRequestScheduler,
@@ -116,11 +117,24 @@ def _step(self):
116117

117118
for requests in grouped_requests:
118119
batch_inputs = self._batcher.batch_request(requests).to(self.device)
119-
batch_output = self._inference_pipeline.infer(
120-
batch_inputs,
121-
predict_length=requests[0].max_new_tokens,
122-
revin=True,
123-
)
120+
if isinstance(self._inference_pipeline, ForecastPipeline):
121+
batch_output = self._inference_pipeline.forecast(
122+
batch_inputs,
123+
predict_length=requests[0].max_new_tokens,
124+
revin=True,
125+
)
126+
elif isinstance(self._inference_pipeline, ClassificationPipeline):
127+
batch_output = self._inference_pipeline.classify(
128+
batch_inputs,
129+
# more infer kwargs can be added here
130+
)
131+
elif isinstance(self._inference_pipeline, ChatPipeline):
132+
batch_output = self._inference_pipeline.chat(
133+
batch_inputs,
134+
# more infer kwargs can be added here
135+
)
136+
else:
137+
self._logger.error("[Inference] Unsupported pipeline type.")
124138
offset = 0
125139
for request in requests:
126140
request.output_tensor = request.output_tensor.to(self.device)

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,10 @@
1616
# under the License.
1717
#
1818

19-
from abc import ABC
19+
from abc import ABC, abstractmethod
2020

2121
import torch
2222

23-
from iotdb.ainode.core.exception import InferenceModelInternalError
2423
from iotdb.ainode.core.model.model_loader import load_model
2524

2625

@@ -48,12 +47,9 @@ def __init__(self, model_info, **model_kwargs):
4847
super().__init__(model_info, model_kwargs=model_kwargs)
4948

5049
def _preprocess(self, inputs):
51-
if len(inputs.shape) != 2:
52-
raise InferenceModelInternalError(
53-
f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
54-
)
5550
return inputs
5651

52+
@abstractmethod
5753
def forecast(self, inputs, **infer_kwargs):
5854
pass
5955

@@ -68,6 +64,7 @@ def __init__(self, model_info, **model_kwargs):
6864
def _preprocess(self, inputs):
6965
return inputs
7066

67+
@abstractmethod
7168
def classify(self, inputs, **kwargs):
7269
pass
7370

@@ -82,6 +79,7 @@ def __init__(self, model_info, **model_kwargs):
8279
def _preprocess(self, inputs):
8380
return inputs
8481

82+
@abstractmethod
8583
def chat(self, inputs, **kwargs):
8684
pass
8785

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs):
3232
if model_info.model_type == "sktime":
3333
from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline
34+
3435
pipeline_cls = SktimePipeline
3536
elif model_info.category == ModelCategory.BUILTIN:
3637
module_name = (

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
InferenceRequestProxy,
3636
)
3737
from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
38+
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline
3839
from iotdb.ainode.core.inference.pool_controller import PoolController
3940
from iotdb.ainode.core.inference.utils import generate_req_id
4041
from iotdb.ainode.core.log import Logger
@@ -210,9 +211,16 @@ def _run(
210211
else:
211212
model_info = self._model_manager.get_model_info(model_id)
212213
inference_pipeline = load_pipeline(model_info, device="cpu")
213-
outputs = inference_pipeline.infer(
214-
inputs, predict_length=predict_length, **inference_attrs
215-
)
214+
if isinstance(inference_pipeline, ForecastPipeline):
215+
outputs = inference_pipeline.forecast(
216+
inputs, predict_length=predict_length, **inference_attrs
217+
)
218+
elif isinstance(inference_pipeline, ClassificationPipeline):
219+
outputs = inference_pipeline.classify(inputs)
220+
elif isinstance(inference_pipeline, ChatPipeline):
221+
outputs = inference_pipeline.chat(inputs)
222+
else:
223+
logger.error("[Inference] Unsupported pipeline type.")
216224
outputs = convert_to_binary(pd.DataFrame(outputs[0]))
217225

218226
# construct response

iotdb-core/ainode/iotdb/ainode/core/manager/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,9 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
9191
system_res = evaluate_system_resources(device)
9292
free_mem = system_res["free_mem"]
9393

94-
mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
94+
mem_usage = (
95+
MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
96+
)
9597
size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
9698
if size <= 0:
9799
logger.error(

iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,23 @@
33
"model_id": "arima",
44
"predict_length": 1,
55
"order": [1, 0, 0],
6-
"season_length": 1,
7-
"seasonal_order": [0, 0, 0],
8-
"include_mean": true,
9-
"include_drift": false,
10-
"biasadj": false,
11-
"method": "CSS-ML"
6+
"seasonal_order": [0, 0, 0, 0],
7+
"start_params": null,
8+
"method": "lbfgs",
9+
"maxiter": 50,
10+
"suppress_warnings": false,
11+
"out_of_sample_size": 0,
12+
"scoring": "mse",
13+
"scoring_args": null,
14+
"trend": null,
15+
"with_intercept": true,
16+
"time_varying_regression": false,
17+
"enforce_stationarity": true,
18+
"enforce_invertibility": true,
19+
"simple_differencing": false,
20+
"measurement_error": false,
21+
"mle_regression": true,
22+
"hamilton_representation": false,
23+
"concentrate_scale": false
1224
}
1325

iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -176,21 +176,40 @@ def parse(self, string_value: str):
176176
"ARIMA": {
177177
"predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
178178
"order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int),
179-
"season_length": AttributeConfig("season_length", 1, "int", 1, 5000),
180179
"seasonal_order": AttributeConfig(
181-
"seasonal_order", (0, 0, 0), "tuple", value_type=int
180+
"seasonal_order", (0, 0, 0, 0), "tuple", value_type=int
182181
),
183-
"include_mean": AttributeConfig("include_mean", True, "bool"),
184-
"include_drift": AttributeConfig("include_drift", False, "bool"),
185-
"include_constant": AttributeConfig("include_constant", None, "bool"),
186-
"blambda": AttributeConfig("blambda", None, "float"),
187-
"biasadj": AttributeConfig("biasadj", False, "bool"),
182+
"start_params": AttributeConfig("start_params", None, "str"),
188183
"method": AttributeConfig(
189184
"method",
190-
"CSS-ML",
185+
"lbfgs",
191186
"str",
192-
choices=["CSS-ML", "ML", "CSS"],
187+
choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
193188
),
189+
"maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000),
190+
"suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"),
191+
"out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000),
192+
"scoring": AttributeConfig(
193+
"scoring",
194+
"mse",
195+
"str",
196+
choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
197+
),
198+
"scoring_args": AttributeConfig("scoring_args", None, "str"),
199+
"trend": AttributeConfig("trend", None, "str"),
200+
"with_intercept": AttributeConfig("with_intercept", True, "bool"),
201+
"time_varying_regression": AttributeConfig(
202+
"time_varying_regression", False, "bool"
203+
),
204+
"enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"),
205+
"enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"),
206+
"simple_differencing": AttributeConfig("simple_differencing", False, "bool"),
207+
"measurement_error": AttributeConfig("measurement_error", False, "bool"),
208+
"mle_regression": AttributeConfig("mle_regression", True, "bool"),
209+
"hamilton_representation": AttributeConfig(
210+
"hamilton_representation", False, "bool"
211+
),
212+
"concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"),
194213
},
195214
"STL_FORECASTER": {
196215
"predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),

iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sklearn.preprocessing import MinMaxScaler
2525
from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
2626
from sktime.detection.stray import STRAY
27-
from statsforecast.models import ARIMA
27+
from sktime.forecasting.arima import ARIMA
2828
from sktime.forecasting.exp_smoothing import ExponentialSmoothing
2929
from sktime.forecasting.naive import NaiveForecaster
3030
from sktime.forecasting.trend import STLForecaster
@@ -59,12 +59,11 @@ class ForecastingModel(SktimeModel):
5959
def generate(self, data, **kwargs):
6060
"""Execute forecasting"""
6161
try:
62-
predict_length = kwargs.get("predict_length", self._attributes["predict_length"])
62+
predict_length = kwargs.get(
63+
"predict_length", self._attributes["predict_length"]
64+
)
6365
self._model.fit(data)
64-
if isinstance(self._model, ARIMA):
65-
output = self._model.predict(h=predict_length)['mean']
66-
else:
67-
output = self._model.predict(fh=range(predict_length))
66+
output = self._model.predict(fh=range(predict_length))
6867
return np.array(output, dtype=np.float64)
6968
except Exception as e:
7069
raise InferenceModelInternalError(str(e))
@@ -92,7 +91,7 @@ class ArimaModel(ForecastingModel):
9291
def __init__(self, attributes: Dict[str, Any]):
9392
super().__init__(attributes)
9493
self._model = ARIMA(
95-
**{k: v for k, v in attributes.items() if k != "predict_length" and v is not None}
94+
**{k: v for k, v in attributes.items() if k != "predict_length"}
9695
)
9796

9897

@@ -147,9 +146,7 @@ class STRAYModel(DetectionModel):
147146

148147
def __init__(self, attributes: Dict[str, Any]):
149148
super().__init__(attributes)
150-
self._model = STRAY(
151-
**{k: v for k, v in attributes.items() if v is not None}
152-
)
149+
self._model = STRAY(**{k: v for k, v in attributes.items() if v is not None})
153150

154151
def generate(self, data, **kwargs):
155152
"""STRAY requires special handling: normalize first"""

iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@
2020
import pandas as pd
2121
import torch
2222

23-
from iotdb.ainode.core.inference.pipeline.basic_pipeline import BasicPipeline
23+
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
2424

2525

26-
class SktimePipeline(BasicPipeline):
26+
class SktimePipeline(ForecastPipeline):
2727
def __init__(self, model_info, **model_kwargs):
28+
model_kwargs.pop("device", None) # sktime models run on CPU
2829
super().__init__(model_info, model_kwargs=model_kwargs)
29-
model_kwargs.pop("device", None)
3030

3131
def _preprocess(self, inputs):
32-
return super()._preprocess(inputs)
32+
return inputs
3333

34-
def infer(self, inputs, **infer_kwargs):
34+
def forecast(self, inputs, **infer_kwargs):
3535
predict_length = infer_kwargs.get("predict_length", 96)
3636
input_ids = self._preprocess(inputs)
3737

iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,21 @@
1919
import torch
2020

2121
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
22+
from iotdb.ainode.core.exception import InferenceModelInternalError
2223

2324

2425
class SundialPipeline(ForecastPipeline):
2526
def __init__(self, model_info, **model_kwargs):
2627
super().__init__(model_info, model_kwargs=model_kwargs)
2728

2829
def _preprocess(self, inputs):
29-
return super()._preprocess(inputs)
30+
if len(inputs.shape) != 2:
31+
raise InferenceModelInternalError(
32+
f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
33+
)
34+
return inputs
3035

31-
def infer(self, inputs, **infer_kwargs):
36+
def forecast(self, inputs, **infer_kwargs):
3237
predict_length = infer_kwargs.get("predict_length", 96)
3338
num_samples = infer_kwargs.get("num_samples", 10)
3439
revin = infer_kwargs.get("revin", True)

0 commit comments

Comments
 (0)