
Commit 7384701

[AINode] Support dynamic adjustment for pool size (apache#16079)
1 parent 9cd5fa4 commit 7384701

File tree

10 files changed: +367 −44 lines

iotdb-core/ainode/ainode/core/ai_node.py

Lines changed: 2 additions & 0 deletions
@@ -162,4 +162,6 @@ def stop(self):
         if self._rpc_service:
             self._rpc_service.stop()
             self._rpc_service.join(1)
+            if self._rpc_service.is_alive():
+                logger.warning("RPC service thread failed to stop in time.")
         logger.info("IoTDB-AINode has successfully stopped.")

iotdb-core/ainode/ainode/core/config.py

Lines changed: 51 additions & 1 deletion
@@ -31,7 +31,10 @@
     AINODE_CONF_GIT_FILE_NAME,
     AINODE_CONF_POM_FILE_NAME,
     AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
+    AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
     AINODE_INFERENCE_MAX_PREDICT_LENGTH,
+    AINODE_INFERENCE_MEMORY_USAGE_RATIO,
+    AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
     AINODE_LOG_DIR,
     AINODE_MODELS_DIR,
     AINODE_ROOT_CONF_DIRECTORY_NAME,
@@ -76,7 +79,15 @@ def __init__(self):
         self._ain_inference_max_predict_length: int = (
             AINODE_INFERENCE_MAX_PREDICT_LENGTH
         )
-
+        self._ain_inference_model_mem_usage_map: dict[str, int] = (
+            AINODE_INFERENCE_MODEL_MEM_USAGE_MAP
+        )
+        self._ain_inference_memory_usage_ratio: float = (
+            AINODE_INFERENCE_MEMORY_USAGE_RATIO
+        )
+        self._ain_inference_extra_memory_ratio: float = (
+            AINODE_INFERENCE_EXTRA_MEMORY_RATIO
+        )
         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR

@@ -152,6 +163,30 @@ def set_ain_inference_max_predict_length(
     ) -> None:
         self._ain_inference_max_predict_length = ain_inference_max_predict_length

+    def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]:
+        return self._ain_inference_model_mem_usage_map
+
+    def set_ain_inference_model_mem_usage_map(
+        self, ain_inference_model_mem_usage_map: dict[str, int]
+    ) -> None:
+        self._ain_inference_model_mem_usage_map = ain_inference_model_mem_usage_map
+
+    def get_ain_inference_memory_usage_ratio(self) -> float:
+        return self._ain_inference_memory_usage_ratio
+
+    def set_ain_inference_memory_usage_ratio(
+        self, ain_inference_memory_usage_ratio: float
+    ) -> None:
+        self._ain_inference_memory_usage_ratio = ain_inference_memory_usage_ratio
+
+    def get_ain_inference_extra_memory_ratio(self) -> float:
+        return self._ain_inference_extra_memory_ratio
+
+    def set_ain_inference_extra_memory_ratio(
+        self, ain_inference_extra_memory_ratio: float
+    ) -> None:
+        self._ain_inference_extra_memory_ratio = ain_inference_extra_memory_ratio
+
     def get_ain_logs_dir(self) -> str:
         return self._ain_logs_dir

@@ -294,6 +329,21 @@ def _load_config_from_file(self) -> None:
                 int(file_configs["ain_inference_batch_interval_in_ms"])
             )

+        if "ain_inference_model_mem_usage_map" in config_keys:
+            self._config.set_ain_inference_model_mem_usage_map(
+                eval(file_configs["ain_inference_model_mem_usage_map"])
+            )
+
+        if "ain_inference_memory_usage_ratio" in config_keys:
+            self._config.set_ain_inference_memory_usage_ratio(
+                float(file_configs["ain_inference_memory_usage_ratio"])
+            )
+
+        if "ain_inference_extra_memory_ratio" in config_keys:
+            self._config.set_ain_inference_extra_memory_ratio(
+                float(file_configs["ain_inference_extra_memory_ratio"])
+            )
+
        if "ain_models_dir" in config_keys:
            self._config.set_ain_models_dir(file_configs["ain_models_dir"])

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 9 additions & 0 deletions
@@ -21,6 +21,7 @@
 from enum import Enum
 from typing import List

+from ainode.core.model.model_info import BuiltInModelType
 from ainode.thrift.common.ttypes import TEndPoint

 AINODE_VERSION_INFO = "UNKNOWN"
@@ -51,6 +52,14 @@
 # AINode inference configuration
 AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
 AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
+AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
+    BuiltInModelType.SUNDIAL.value: 1036 * 1024**2,  # 1036 MiB
+    BuiltInModelType.TIMER_XL.value: 856 * 1024**2,  # 856 MiB
+}  # the memory usage of each model in bytes
+AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4  # the device space allocated for inference
+AINODE_INFERENCE_EXTRA_MEMORY_RATIO = (
+    1.2  # the overhead ratio for inference, used to estimate the pool size
+)

 # AINode folder structure
 AINODE_ROOT_DIR = os.path.dirname(
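
These three constants are the sizing inputs named in the commit title: a measured per-model footprint, the fraction of device memory reserved for inference, and a 1.2x overhead margin. The pool-manager code that combines them is in one of the files not shown in this excerpt, but their comments suggest arithmetic roughly like the following sketch (the helper is hypothetical, and the CPU fallback budget is an assumption):

import torch

from ainode.core.constant import (
    AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
    AINODE_INFERENCE_MEMORY_USAGE_RATIO,
    AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
)

def estimate_pool_size(model_id: str) -> int:
    """Hypothetical helper: how many inference pools fit on the device."""
    if torch.cuda.is_available():
        total = torch.cuda.get_device_properties(0).total_memory
    else:
        total = 16 * 1024**3  # assumed fallback budget for CPU-only hosts

    # Share of device memory reserved for inference.
    budget = total * AINODE_INFERENCE_MEMORY_USAGE_RATIO
    # Per-pool footprint: measured model memory plus the 20% overhead margin.
    per_pool = (
        AINODE_INFERENCE_MODEL_MEM_USAGE_MAP[model_id]  # keyed by BuiltInModelType value
        * AINODE_INFERENCE_EXTRA_MEMORY_RATIO
    )
    return max(1, int(budget // per_pool))

Under these numbers, a 24 GiB GPU would yield int((24 * 1024**3 * 0.4) // (1036 * 1024**2 * 1.2)) = 7 Sundial pools.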

iotdb-core/ainode/ainode/core/inference/inference_request_pool.py

Lines changed: 25 additions & 14 deletions
@@ -50,6 +50,7 @@ def __init__(
         config: PretrainedConfig,
         request_queue: mp.Queue,
         result_queue: mp.Queue,
+        ready_event,
         **pool_kwargs,
     ):
         super().__init__()
@@ -59,11 +60,8 @@ def __init__(
         self.pool_kwargs = pool_kwargs
         self.model = None
         self._model_manager = None
-        # TODO: Assign device immediately when the pool is created
         self.device = None
-        self.logger = Logger(
-            INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
-        )
+        self.ready_event = ready_event

         self._threads = []
         self._waiting_queue = request_queue  # Requests that are waiting to be processed
@@ -128,15 +126,25 @@ def _step(self):
         requests = self._scheduler.schedule_step()
         # TODO: We need a batcher to accelerate the concurrent inference
         for request in requests:
-            request.inputs = request.inputs.to(self.device)
-            output = self.model.generate(
-                request.inputs,
-                max_new_tokens=request.max_new_tokens,
-                num_samples=10,
-                revin=True,
-            )
-            request.output_tensor = request.output_tensor.to(self.device)
-            request.write_step_output(output[0].mean(dim=0))
+            if self.model_id == "sundial":
+                request.inputs = request.inputs.to(self.device)
+                output = self.model.generate(
+                    request.inputs,
+                    max_new_tokens=request.max_new_tokens,
+                    num_samples=10,
+                    revin=True,
+                )
+                request.output_tensor = request.output_tensor.to(self.device)
+                request.write_step_output(output[0].mean(dim=0))
+            elif self.model_id == "timer_xl":
+                request.inputs = request.inputs.to(self.device)
+                output = self.model.generate(
+                    request.inputs,
+                    max_new_tokens=request.max_new_tokens,
+                    revin=True,
+                )
+                request.output_tensor = request.output_tensor.to(self.device)
+                request.write_step_output(output[0])
             request.inference_pipeline.post_decode()
             if request.is_finished():
                 request.inference_pipeline.post_inference()
@@ -160,8 +168,12 @@ def _requests_execute_loop(self):
     def run(self):
         self._model_manager = ModelManager()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.logger = Logger(
+            INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
+        )
         self._scheduler.device = self.device
         self.model = self._model_manager.load_model(self.model_id, {}).to(self.device)
+        self.ready_event.set()

         # self._warm_up_and_estimate_memory()
@@ -183,4 +195,3 @@ def run(self):

     def stop(self):
         self._stop_event.set()
-        self.logger.info(f"[Inference][Pool-{self.pool_id}] stop() called")
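
The new ready_event closes a startup race: run() executes in a child process, so the pool is only usable once load_model(...) has finished there. A hedged caller-side sketch of the handshake — the pool-manager code that actually does this is in a file not shown here, and the keyword names before config are assumptions, since the diff only shows the tail of the constructor signature:

import multiprocessing as mp

from transformers import PretrainedConfig

from ainode.core.inference.inference_request_pool import InferenceRequestPool

ctx = mp.get_context("spawn")
request_queue, result_queue = ctx.Queue(), ctx.Queue()
ready_event = ctx.Event()

pool = InferenceRequestPool(
    pool_id=0,                 # assumed keyword; self.pool_id exists per the diff
    model_id="timer_xl",       # assumed keyword; self.model_id exists per the diff
    config=PretrainedConfig(), # placeholder config for illustration
    request_queue=request_queue,
    result_queue=result_queue,
    ready_event=ready_event,
)
pool.start()  # the class appears to subclass mp.Process, judging by run()/super().__init__()

# run() sets the event only after the model is loaded onto its device, so
# waiting here guarantees the pool can serve requests; the timeout guards
# against a child that dies during startup.
if not ready_event.wait(timeout=60):
    raise RuntimeError("inference pool did not become ready in time")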

iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ def memory_is_available(self):
         )
         logger.debug(
             f"[Inference][Device-{self.device}][Pool-{self.pool_id}] "
-            f"Memory used: {used} bytes, Max memory: {self.max_memory_bytes} bytes"
+            f"Memory used: {used/1024**2:.2f} MB, Max memory: {self.max_memory_bytes/1024**2:.2f} MB"
         )
         return used < self.max_memory_bytes
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from ainode.core.exception import InferenceModelInternalError
+from ainode.core.inference.strategy.abstract_inference_pipeline import (
+    AbstractInferencePipeline,
+)
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+
+
+class TimerXLInferencePipeline(AbstractInferencePipeline):
+    """
+    Strategy for Timer-XL model inference.
+    """
+
+    def __init__(self, model_config: TimerConfig, **infer_kwargs):
+        super().__init__(model_config, infer_kwargs=infer_kwargs)
+
+    def preprocess_inputs(self, inputs: torch.Tensor):
+        super().preprocess_inputs(inputs)
+        if len(inputs.shape) != 2:
+            raise InferenceModelInternalError(
+                f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
+            )
+        # TODO: Disassemble and adapt with TimerXL's ts_generation_mixin.py
+        return inputs
+
+    def post_decode(self):
+        # TODO: Disassemble and adapt with TimerXL's ts_generation_mixin.py
+        pass
+
+    def post_inference(self):
+        # TODO: Disassemble and adapt with TimerXL's ts_generation_mixin.py
+        pass
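
This new file gives Timer-XL the same AbstractInferencePipeline strategy interface the Sundial model already uses, matching the sundial / timer_xl branch added to _step() above. How pipelines are selected per model is not visible in this excerpt; a hedged sketch of one way to do it, assumed to live alongside the pipeline classes so no cross-module import of them is needed:

from ainode.core.exception import InferenceModelInternalError

# Hypothetical registry keyed by the same model ids _step() branches on;
# extending the dict would be the only change when a new built-in model lands.
_PIPELINES = {
    "timer_xl": TimerXLInferencePipeline,
    # "sundial": SundialInferencePipeline,  # assumed existing counterpart class
}

def build_inference_pipeline(model_id, model_config, **infer_kwargs):
    if model_id not in _PIPELINES:
        raise InferenceModelInternalError(
            f"[Inference] No pipeline registered for model '{model_id}'"
        )
    return _PIPELINES[model_id](model_config, **infer_kwargs)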
