Commit ec9a507

Seems finish
1 parent 1d852bf commit ec9a507

16 files changed (+110, -130 lines)

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java

Lines changed: 6 additions & 0 deletions
@@ -20,6 +20,8 @@
 package org.apache.iotdb.ainode.it;
 
 import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
 import org.apache.iotdb.itbase.env.BaseEnv;
 
 import com.google.common.collect.ImmutableMap;
@@ -28,6 +30,8 @@
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -42,6 +46,8 @@
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
 
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
 public class AINodeConcurrentInferenceIT {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
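The added annotations are standard JUnit 4 wiring: @RunWith(IoTDBTestRunner.class) runs the test through IoTDB's custom runner, and @Category({AIClusterIT.class}) tags it into the AIClusterIT category so the AI-cluster test profile can select it.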

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 14 additions & 8 deletions
@@ -36,6 +36,8 @@
 )
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.model_manager import ModelManager
+from iotdb.ainode.core.model.model_enums import BuiltInModelType
+from iotdb.ainode.core.model.model_info import ModelInfo
 from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device


@@ -58,7 +60,7 @@ class InferenceRequestPool(mp.Process):
     def __init__(
         self,
         pool_id: int,
-        model_id: str,
+        model_info: ModelInfo,
         device: str,
         config: PretrainedConfig,
         request_queue: mp.Queue,
@@ -68,7 +70,7 @@ def __init__(
     ):
         super().__init__()
         self.pool_id = pool_id
-        self.model_id = model_id
+        self.model_info = model_info
         self.config = config
         self.pool_kwargs = pool_kwargs
         self.ready_event = ready_event
@@ -121,7 +123,7 @@ def _step(self):
 
         for requests in grouped_requests:
             batch_inputs = self._batcher.batch_request(requests).to(self.device)
-            if self.model_id == "sundial":
+            if self.model_info.model_type == BuiltInModelType.SUNDIAL.value:
                 batch_output = self._model.generate(
                     batch_inputs,
                     max_new_tokens=requests[0].max_new_tokens,
@@ -135,8 +137,7 @@ def _step(self):
                     cur_batch_size = request.batch_size
                     cur_output = batch_output[offset : offset + cur_batch_size]
                     offset += cur_batch_size
-                    # TODO Here we only considered the case where batchsize=1 in one request. If multi-variable adaptation is required in the future, modifications may be needed here, such as: `cur_output[0]` maybe not true in multi-variable scene
-                    request.write_step_output(cur_output[0].mean(dim=0))
+                    request.write_step_output(cur_output.mean(dim=1))
 
                     request.inference_pipeline.post_decode()
                     if request.is_finished():
@@ -153,7 +154,7 @@ def _step(self):
                         )
                         self._waiting_queue.put(request)
 
-            elif self.model_id == "timer_xl":
+            elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value:
                 batch_output = self._model.generate(
                     batch_inputs,
                     max_new_tokens=requests[0].max_new_tokens,
@@ -194,7 +195,9 @@ def run(self):
         )
         self._model_manager = ModelManager()
         self._request_scheduler.device = self.device
-        self._model = self._model_manager.load_model(self.model_id, {}).to(self.device)
+        self._model = self._model_manager.load_model(self.model_info.model_id, {}).to(
+            self.device
+        )
         self.ready_event.set()
 
         activate_daemon = threading.Thread(
@@ -207,10 +210,13 @@ def run(self):
         )
         self._threads.append(execute_daemon)
         execute_daemon.start()
+        self._logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
+        )
         for thread in self._threads:
             thread.join()
         self._logger.info(
-            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_id} exited cleanly."
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
         )
 
     def stop(self):
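The aggregation change above removes the batch-size-1 assumption flagged in the old TODO: `cur_output[0].mean(dim=0)` kept only the first series of a request's slice, while `cur_output.mean(dim=1)` averages the sample dimension for every series. A minimal sketch, assuming `batch_output` has shape (batch_size, num_samples, prediction_length) as returned by `generate()` (the shape and sizes below are illustrative, not taken from the commit):

import torch

batch_output = torch.randn(4, 10, 96)   # 4 series, 10 sampled trajectories, 96 steps
cur_output = batch_output[0:2]          # slice for one request with batch_size=2

old_style = cur_output[0].mean(dim=0)   # (96,)   -> only the first series survives
new_style = cur_output.mean(dim=1)      # (2, 96) -> one averaged forecast per series

assert torch.allclose(new_style[0], old_style)  # first series unchanged, rest now kept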

iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py

Lines changed: 15 additions & 6 deletions
@@ -40,6 +40,8 @@
     ScaleActionType,
 )
 from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.model_manager import ModelManager
+from iotdb.ainode.core.model.model_enums import BuiltInModelType
 from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
 from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
 from iotdb.ainode.core.util.atmoic_int import AtomicInt
@@ -48,6 +50,7 @@
 from iotdb.ainode.core.util.thread_name import ThreadName
 
 logger = Logger()
+MODEL_MANAGER = ModelManager()
 
 
 class PoolController:
@@ -169,7 +172,7 @@ def show_loaded_models(
         for model_id, device_map in self._request_pool_map.items():
             if device_id in device_map:
                 pool_group = device_map[device_id]
-                device_models[model_id] = pool_group.get_pool_count()
+                device_models[model_id] = pool_group.get_running_pool_count()
             result[device_id] = device_models
         return result
 
@@ -191,7 +194,7 @@ def _load_model_task(self, model_id: str, device_id_list: list[str]):
         def _load_model_on_device_task(device_id: str):
             if not self.has_request_pools(model_id, device_id):
                 actions = self._pool_scheduler.schedule_load_model_to_device(
-                    model_id, device_id
+                    MODEL_MANAGER.get_model_info(model_id), device_id
                 )
                 for action in actions:
                     if action.action == ScaleActionType.SCALE_UP:
@@ -218,7 +221,7 @@ def _unload_model_task(self, model_id: str, device_id_list: list[str]):
         def _unload_model_on_device_task(device_id: str):
             if self.has_request_pools(model_id, device_id):
                 actions = self._pool_scheduler.schedule_unload_model_from_device(
-                    model_id, device_id
+                    MODEL_MANAGER.get_model_info(model_id), device_id
                 )
                 for action in actions:
                     if action.action == ScaleActionType.SCALE_DOWN:
@@ -253,13 +256,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
         def _expand_pool_on_device(*_):
             result_queue = mp.Queue()
             pool_id = self._new_pool_id.get_and_increment()
-            if model_id == "sundial":
+            model_info = MODEL_MANAGER.get_model_info(model_id)
+            model_type = model_info.model_type
+            if model_type == BuiltInModelType.SUNDIAL.value:
                 config = SundialConfig()
-            elif model_id == "timer_xl":
+            elif model_type == BuiltInModelType.TIMER_XL.value:
                 config = TimerConfig()
+            else:
+                raise InferenceModelInternalError(
+                    f"Unsupported model type {model_type} for loading model {model_id}"
+                )
             pool = InferenceRequestPool(
                 pool_id=pool_id,
-                model_id=model_id,
+                model_info=model_info,
                 device=device_id,
                 config=config,
                 request_queue=result_queue,
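`_expand_pool_on_device` now resolves the pool's config from `ModelInfo.model_type` instead of the raw `model_id`, and rejects unknown types instead of leaving `config` unbound. A rough, self-contained sketch of that dispatch (the `ModelInfo` stand-in and the string values are assumptions mirroring `BuiltInModelType`, not the committed classes):

from dataclasses import dataclass

@dataclass
class ModelInfo:  # hypothetical stand-in for iotdb.ainode.core.model.model_info.ModelInfo
    model_id: str
    model_type: str

def pick_config(info: ModelInfo) -> str:
    # Mirrors the SUNDIAL/TIMER_XL branches above; the real code returns
    # SundialConfig()/TimerConfig() instances rather than names.
    if info.model_type == "sundial":    # BuiltInModelType.SUNDIAL.value (assumed)
        return "SundialConfig"
    if info.model_type == "timer_xl":   # BuiltInModelType.TIMER_XL.value (assumed)
        return "TimerConfig"
    raise ValueError(f"Unsupported model type {info.model_type} for loading model {info.model_id}")

print(pick_config(ModelInfo("sundial", "sundial")))  # -> SundialConfig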

iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,12 @@ def get_pool_ids(self) -> list[int]:
     def get_pool_count(self) -> int:
         return len(self.pool_group)
 
+    def get_running_pool_count(self) -> int:
+        count = 0
+        for _, state in self.pool_states.items():
+            count += 1 if state == PoolState.RUNNING else 0
+        return count
+
     def dispatch_request(
         self, req: InferenceRequest, infer_proxy: InferenceRequestProxy
     ):
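The new counter only tallies pools in the RUNNING state, so `show_loaded_models` no longer reports pools that are still starting up or shutting down. A minimal sketch of the same tally, assuming `pool_states` maps pool ids to a `PoolState` enum (the enum values here are assumed):

from enum import Enum

class PoolState(Enum):  # hypothetical mirror of the real PoolState
    RUNNING = "running"
    STOPPED = "stopped"

pool_states = {0: PoolState.RUNNING, 1: PoolState.STOPPED, 2: PoolState.RUNNING}

# Equivalent to the loop in get_running_pool_count above:
running = sum(1 for state in pool_states.values() if state is PoolState.RUNNING)
print(running)  # -> 2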

iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/abstract_pool_scheduler.py

Lines changed: 5 additions & 4 deletions
@@ -22,6 +22,7 @@
 from typing import Dict, List
 
 from iotdb.ainode.core.inference.pool_group import PoolGroup
+from iotdb.ainode.core.model.model_info import ModelInfo
 
 
 class ScaleActionType(Enum):
@@ -58,12 +59,12 @@ def schedule(self, model_id: str) -> List[ScaleAction]:
 
     @abstractmethod
     def schedule_load_model_to_device(
-        self, model_id: str, device_id: str
+        self, model_info: ModelInfo, device_id: str
     ) -> List[ScaleAction]:
         """
         Schedule a series of actions to load the model to the device.
         Args:
-            model_id: The model to be loaded.
+            model_info: The model to be loaded.
             device_id: The device to load the model to.
         Returns:
             A list of ScaleAction to be performed.
@@ -72,12 +73,12 @@ def schedule_load_model_to_device(
 
     @abstractmethod
     def schedule_unload_model_from_device(
-        self, model_id: str, device_id: str
+        self, model_info: ModelInfo, device_id: str
     ) -> List[ScaleAction]:
         """
         Schedule a series of actions to unload the model from the device.
         Args:
-            model_id: The model to be unloaded.
+            model_info: The model to be unloaded.
             device_id: The device to unload the model from.
         Returns:
             A list of ScaleAction to be performed.

iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py

Lines changed: 32 additions & 29 deletions
@@ -28,23 +28,26 @@
     ScaleActionType,
 )
 from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.manager.model_manager import ModelManager
 from iotdb.ainode.core.manager.utils import (
     INFERENCE_EXTRA_MEMORY_RATIO,
     INFERENCE_MEMORY_USAGE_RATIO,
     MODEL_MEM_USAGE_MAP,
     estimate_pool_size,
     evaluate_system_resources,
 )
-from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP
+from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo
 from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
 
 logger = Logger()
 
+MODEL_MANAGER = ModelManager()
+
 
 def _estimate_shared_pool_size_by_total_mem(
     device: torch.device,
-    existing_model_ids: List[str],
-    new_model_id: Optional[str] = None,
+    existing_model_infos: List[ModelInfo],
+    new_model_info: Optional[ModelInfo] = None,
 ) -> Dict[str, int]:
     """
     Estimate pool counts for (existing_model_ids + new_model_id) by equally
@@ -54,17 +57,15 @@ def _estimate_shared_pool_size_by_total_mem(
         mapping {model_id: pool_num}
     """
     # Extract unique model IDs
-    all_models = existing_model_ids + (
-        [new_model_id] if new_model_id is not None else []
+    all_models = existing_model_infos + (
+        [new_model_info] if new_model_info is not None else []
     )
 
     # Seize memory usage for each model
     mem_usages: Dict[str, float] = {}
-    for model_id in all_models:
-        model_info = BUILT_IN_LTSM_MAP.get(model_id)
-        model_type = model_info.model_type
-        mem_usages[model_id] = (
-            MODEL_MEM_USAGE_MAP[model_type] * INFERENCE_EXTRA_MEMORY_RATIO
+    for model_info in all_models:
+        mem_usages[model_info.model_id] = (
+            MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
         )
 
     # Evaluate system resources and get TOTAL memory
@@ -84,14 +85,14 @@ def _estimate_shared_pool_size_by_total_mem(
 
     # Calculate pool allocation for each model
     allocation: Dict[str, int] = {}
-    for model_id in all_models:
-        pool_num = int(per_model_share // mem_usages[model_id])
+    for model_info in all_models:
+        pool_num = int(per_model_share // mem_usages[model_info.model_id])
         if pool_num <= 0:
             logger.warning(
-                f"[Inference][Device-{device}] Not enough TOTAL memory to guarantee at least 1 pool for model {model_id}, no pool will be scheduled for this model. "
-                f"Per-model share={per_model_share / 1024 ** 2:.2f} MB, need>={mem_usages[model_id] / 1024 ** 2:.2f} MB"
+                f"[Inference][Device-{device}] Not enough TOTAL memory to guarantee at least 1 pool for model {model_info.model_id}, no pool will be scheduled for this model. "
+                f"Per-model share={per_model_share / 1024 ** 2:.2f} MB, need>={mem_usages[model_info.model_id] / 1024 ** 2:.2f} MB"
             )
-        allocation[model_id] = pool_num
+        allocation[model_info.model_id] = pool_num
     logger.info(
         f"[Inference][Device-{device}] Shared pool allocation (by TOTAL memory): {allocation}"
     )
@@ -119,39 +120,41 @@ def schedule(self, model_id: str) -> List[ScaleAction]:
         return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)]
 
     def schedule_load_model_to_device(
-        self, model_id: str, device_id: str
+        self, model_info: ModelInfo, device_id: str
     ) -> List[ScaleAction]:
-        existing_model_ids = [
-            existing_model_id
+        existing_model_infos = [
+            MODEL_MANAGER.get_model_info(existing_model_id)
             for existing_model_id, pool_group_map in self._request_pool_map.items()
-            if existing_model_id != model_id and device_id in pool_group_map
+            if existing_model_id != model_info.model_id and device_id in pool_group_map
         ]
        allocation_result = _estimate_shared_pool_size_by_total_mem(
             device=convert_device_id_to_torch_device(device_id),
-            existing_model_ids=existing_model_ids,
-            new_model_id=model_id,
+            existing_model_infos=existing_model_infos,
+            new_model_info=model_info,
         )
         return self._convert_allocation_result_to_scale_actions(
             allocation_result, device_id
         )
 
     def schedule_unload_model_from_device(
-        self, model_id: str, device_id: str
+        self, model_info: ModelInfo, device_id: str
     ) -> List[ScaleAction]:
-        existing_model_ids = [
-            existing_model_id
+        existing_model_infos = [
+            MODEL_MANAGER.get_model_info(existing_model_id)
            for existing_model_id, pool_group_map in self._request_pool_map.items()
-            if existing_model_id != model_id and device_id in pool_group_map
+            if existing_model_id != model_info.model_id and device_id in pool_group_map
         ]
         allocation_result = (
             _estimate_shared_pool_size_by_total_mem(
                 device=convert_device_id_to_torch_device(device_id),
-                existing_model_ids=existing_model_ids,
-                new_model_id=None,
+                existing_model_infos=existing_model_infos,
+                new_model_info=None,
             )
-            if len(existing_model_ids) > 0
-            else {model_id: 0}
+            if len(existing_model_infos) > 0
+            else {model_info.model_id: 0}
         )
+        if len(existing_model_infos) > 0:
+            allocation_result[model_info.model_id] = 0
         return self._convert_allocation_result_to_scale_actions(
             allocation_result, device_id
        )
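The allocation arithmetic splits the usable device memory evenly across models and floors each model's share by its per-pool footprint; on unload, the departing model's entry is now pinned to 0 so its pools scale down while the survivors are re-balanced. A worked example of that division with made-up numbers (the 8 GiB budget and the footprints below are illustrative only, not the values in MODEL_MEM_USAGE_MAP):

GiB = 1024 ** 3
usable_total = 8 * GiB                # after INFERENCE_MEMORY_USAGE_RATIO (assumed)
mem_usages = {
    "sundial": 1.5 * GiB,             # already scaled by INFERENCE_EXTRA_MEMORY_RATIO
    "timer_xl": 1.0 * GiB,
}

per_model_share = usable_total / len(mem_usages)  # 4 GiB per model

allocation = {mid: int(per_model_share // usage) for mid, usage in mem_usages.items()}
print(allocation)  # {'sundial': 2, 'timer_xl': 4}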

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 4 additions & 2 deletions
@@ -47,6 +47,7 @@
 from iotdb.ainode.core.inference.utils import generate_req_id
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.model_manager import ModelManager
+from iotdb.ainode.core.model.model_enums import BuiltInModelType
 from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
 from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction
 from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
@@ -297,9 +298,10 @@ def _run(
         data = np_data.view(np_data.dtype.newbyteorder())
         # the inputs should be on CPU before passing to the inference request
         inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
-        if model_id == "sundial":
+        model_type = self._model_manager.get_model_info(model_id).model_type
+        if model_type == BuiltInModelType.SUNDIAL.value:
             inference_pipeline = TimerSundialInferencePipeline(SundialConfig())
-        elif model_id == "timer_xl":
+        elif model_type == BuiltInModelType.TIMER_XL.value:
             inference_pipeline = TimerXLInferencePipeline(TimerConfig())
         else:
             raise InferenceModelInternalError(

iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py

Lines changed: 3 additions & 0 deletions
@@ -144,6 +144,9 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
     def register_built_in_model(self, model_info: ModelInfo):
         self.model_storage.register_built_in_model(model_info)
 
+    def get_model_info(self, model_id: str) -> ModelInfo:
+        return self.model_storage.get_model_info(model_id)
+
     def update_model_state(self, model_id: str, state: ModelStates):
         self.model_storage.update_model_state(model_id, state)

iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py

Lines changed: 7 additions & 0 deletions
@@ -423,6 +423,13 @@ def register_built_in_model(self, model_info: ModelInfo):
         with self._lock_pool.get_lock(model_info.model_id).write_lock():
             self._model_info_map[model_info.model_id] = model_info
 
+    def get_model_info(self, model_id: str) -> ModelInfo:
+        with self._lock_pool.get_lock(model_id).read_lock():
+            if model_id in self._model_info_map:
+                return self._model_info_map[model_id]
+            else:
+                raise ValueError(f"Model {model_id} does not exist.")
+
     def update_model_state(self, model_id: str, state: ModelStates):
         with self._lock_pool.get_lock(model_id).write_lock():
             if model_id in self._model_info_map:
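`get_model_info` takes the per-model read lock, so concurrent lookups can proceed in parallel while still being excluded by writers such as `register_built_in_model`. Hypothetical usage via the `ModelManager` facade added above (assumes an AINode environment where the model was registered):

from iotdb.ainode.core.manager.model_manager import ModelManager

manager = ModelManager()
try:
    info = manager.get_model_info("sundial")
    print(info.model_id, info.model_type)
except ValueError as err:
    print(err)  # "Model sundial does not exist." if it was never registered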
