Skip to content

Commit f1870cb

Browse files
authored
[AINode] Limit max inference length (apache#15982)
1 parent 1d3a35e commit f1870cb

File tree

7 files changed

+39
-10
lines changed

7 files changed

+39
-10
lines changed

iotdb-core/ainode/ainode/core/ainode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def start(self):
134134
raise e
135135

136136
# Start the RPC service
137-
self._rpc_handler = AINodeRPCServiceHandler(aiNode=self)
137+
self._rpc_handler = AINodeRPCServiceHandler(ainode=self)
138138
self._rpc_service = AINodeRPCService(self._rpc_handler)
139139
self._rpc_service.start()
140140
self._rpc_service.join(1)

iotdb-core/ainode/ainode/core/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AINODE_CONF_GIT_FILE_NAME,
3232
AINODE_CONF_POM_FILE_NAME,
3333
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
34+
AINODE_INFERENCE_MAX_PREDICT_LENGTH,
3435
AINODE_LOG_DIR,
3536
AINODE_MODELS_DIR,
3637
AINODE_ROOT_CONF_DIRECTORY_NAME,
@@ -72,6 +73,9 @@ def __init__(self):
7273
self._ain_inference_batch_interval_in_ms: int = (
7374
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
7475
)
76+
self._ain_inference_max_predict_length: int = (
77+
AINODE_INFERENCE_MAX_PREDICT_LENGTH
78+
)
7579

7680
# log directory
7781
self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -140,6 +144,14 @@ def set_ain_inference_batch_interval_in_ms(
140144
) -> None:
141145
self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
142146

147+
def get_ain_inference_max_predict_length(self) -> int:
148+
return self._ain_inference_max_predict_length
149+
150+
def set_ain_inference_max_predict_length(
151+
self, ain_inference_max_predict_length: int
152+
) -> None:
153+
self._ain_inference_max_predict_length = ain_inference_max_predict_length
154+
143155
def get_ain_logs_dir(self) -> str:
144156
return self._ain_logs_dir
145157

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050

5151
# AINode inference configuration
5252
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
53+
AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
5354

5455
# AINode folder structure
5556
AINODE_ROOT_DIR = os.path.dirname(

iotdb-core/ainode/ainode/core/inference/inference_request_pool.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424
import torch
2525
import torch.multiprocessing as mp
26-
from transformers import PretrainedConfig, PreTrainedModel
26+
from transformers import PretrainedConfig
2727

2828
from ainode.core.config import AINodeDescriptor
2929
from ainode.core.inference.inference_request import InferenceRequest
@@ -46,7 +46,7 @@ class InferenceRequestPool(mp.Process):
4646
def __init__(
4747
self,
4848
pool_id: int,
49-
model_id: int,
49+
model_id: str,
5050
config: PretrainedConfig,
5151
request_queue: mp.Queue,
5252
result_queue: mp.Queue,
@@ -58,6 +58,7 @@ def __init__(
5858
self.config = config
5959
self.pool_kwargs = pool_kwargs
6060
self.model = None
61+
self._model_manager = None
6162
self.device = None
6263

6364
# TODO: A scheduler is necessary for better handling following queues

iotdb-core/ainode/ainode/core/inference/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
2323

2424

25-
def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
25+
def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
2626
"""
2727
Generate a random req_id string of specified length.
2828
The length is 10 by default, with 10^{17} possible combinations.

iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ainode.core.exception import (
3131
InferenceModelInternalError,
3232
InvalidWindowArgumentError,
33+
NumericalRangeException,
3334
runtime_error_extractor,
3435
)
3536
from ainode.core.inference.inference_request import (
@@ -40,7 +41,7 @@
4041
from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
4142
TimerSundialInferencePipeline,
4243
)
43-
from ainode.core.inference.utils import _generate_req_id
44+
from ainode.core.inference.utils import generate_req_id
4445
from ainode.core.log import Logger
4546
from ainode.core.manager.model_manager import ModelManager
4647
from ainode.core.model.sundial.configuration_sundial import SundialConfig
@@ -214,6 +215,20 @@ def _run(
214215
full_data = deserializer(raw)
215216
inference_attrs = extract_attrs(req)
216217

218+
predict_length = inference_attrs.get("predict_length", 96)
219+
if (
220+
predict_length
221+
> AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
222+
):
223+
raise NumericalRangeException(
224+
"output_length",
225+
1,
226+
AINodeDescriptor()
227+
.get_config()
228+
.get_ain_inference_max_predict_length(),
229+
predict_length,
230+
)
231+
217232
if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0:
218233
# TODO: Logic in this branch shall handle all LTSM inferences
219234
# TODO: TSBlock -> Tensor codes should be unified
@@ -223,10 +238,10 @@ def _run(
223238
# the inputs should be on CPU before passing to the inference request
224239
inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
225240
infer_req = InferenceRequest(
226-
req_id=_generate_req_id(),
241+
req_id=generate_req_id(),
227242
inputs=inputs,
228243
inference_pipeline=TimerSundialInferencePipeline(SundialConfig()),
229-
max_new_tokens=inference_attrs.get("predict_length", 96),
244+
max_new_tokens=predict_length,
230245
)
231246
infer_proxy = InferenceRequestProxy(infer_req.req_id)
232247
with self._result_wrapper_lock:

iotdb-core/ainode/ainode/core/rpc/handler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@
4141

4242

4343
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
44-
def __init__(self, aiNode):
45-
self._aiNode = aiNode
44+
def __init__(self, ainode):
45+
self._ainode = ainode
4646
self._model_manager = ModelManager()
4747
self._inference_manager = InferenceManager()
4848

4949
def stopAINode(self) -> TSStatus:
50-
self._aiNode.stop()
50+
self._ainode.stop()
5151
return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.")
5252

5353
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:

0 commit comments

Comments (0)