
Commit e271df7

[AINode] Decoupling inference manager into request_manager, pool_manager (#16131)
1 parent a809390 commit e271df7

8 files changed: +442 −128 lines

iotdb-core/ainode/ainode/core/inference/inference_request.py

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ class InferenceRequest:
     def __init__(
         self,
         req_id: str,
+        model_id: str,
         inputs: torch.Tensor,
         inference_pipeline: AbstractInferencePipeline,
         max_new_tokens: int = 96,
@@ -47,6 +48,7 @@ def __init__(
             inputs = inputs.unsqueeze(0)
 
         self.req_id = req_id
+        self.model_id = model_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
         self.inference_pipeline = inference_pipeline
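With `model_id` carried on the request itself, downstream code such as `PoolController.add_request` (later in this commit) can resolve the right queue from `req.model_id` instead of threading the model ID through every call. A hypothetical construction sketch — the pipeline object and tensor shape are illustrative assumptions, and the constructor has further parameters this hunk does not show:

import torch

# Hypothetical: `pipeline` is some concrete AbstractInferencePipeline instance.
req = InferenceRequest(
    req_id="req-0001",
    model_id="timer_xl",     # new in this commit; used for pool routing
    inputs=torch.randn(96),  # the diff shows __init__ calling inputs.unsqueeze(0),
                             # presumably to add a batch dimension to 1-D inputs
    inference_pipeline=pipeline,
    max_new_tokens=96,
)
assert req.model_id == "timer_xl"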

iotdb-core/ainode/ainode/core/inference/inference_request_pool.py

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,7 @@
 import random
 import threading
 import time
+from enum import Enum
 
 import numpy as np
 import torch
@@ -33,6 +34,12 @@
 from ainode.core.manager.model_manager import ModelManager
 
 
+class PoolState(Enum):
+    INITIALIZING = "INITIALIZING"
+    RUNNING = "RUNNING"
+    STOPPING = "STOPPING"
+
+
 class InferenceRequestPool(mp.Process):
     """
     The request pool to handle inference for a specific model.
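The new `PoolState` enum drives the lifecycle and routing decisions in the `PoolController` and `PoolScheduler` added below: pools are created in `INITIALIZING`, promoted to `RUNNING` once ready, and only `RUNNING` pools receive requests. A minimal self-contained restatement of that eligibility check:

from enum import Enum

class PoolState(Enum):
    INITIALIZING = "INITIALIZING"  # process started, model still loading
    RUNNING = "RUNNING"            # ready_event fired; eligible for requests
    STOPPING = "STOPPING"          # draining; should receive no new work

def is_dispatchable(state: PoolState) -> bool:
    # Mirrors the check in PoolController._select_pool_by_hash below.
    return state is PoolState.RUNNING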
iotdb-core/ainode/ainode/core/inference/inference_request_pool_group.py

Lines changed: 63 additions & 0 deletions (new file; path recovered from the import in pool_controller.py below)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from typing import Dict, Tuple

import torch.multiprocessing as mp

from ainode.core.exception import (
    InferenceModelInternalError,
)
from ainode.core.inference.inference_request_pool import InferenceRequestPool
from ainode.core.log import Logger

logger = Logger()


class PoolGroup:
    """
    A group of inference request pools for a specific model.
    """

    def __init__(self, model_id):
        self.pool_group: Dict[int, Tuple[InferenceRequestPool, mp.Queue]] = {}
        self.model_id = model_id

    def get_pool_group(self) -> Dict[int, Tuple[InferenceRequestPool, mp.Queue]]:
        return self.pool_group

    def add_pool(
        self, pool_id: int, request_pool: InferenceRequestPool, request_queue: mp.Queue
    ):
        self.pool_group[pool_id] = (request_pool, request_queue)

    def get_pool_ids(self) -> list[int]:
        return list(self.pool_group.keys())

    def get_request_pool(self, pool_id) -> InferenceRequestPool:
        if pool_id not in self.pool_group:
            raise InferenceModelInternalError(
                f"Pool ID {pool_id} not found for model {self.model_id}"
            )
        return self.pool_group[pool_id][0]

    def get_request_queue(self, pool_id) -> mp.Queue:
        if pool_id not in self.pool_group:
            raise InferenceModelInternalError(
                f"Pool ID {pool_id} not found for model {self.model_id}"
            )
        return self.pool_group[pool_id][1]
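A hypothetical exercise of the `PoolGroup` API. The pool constructor arguments mirror `PoolScheduler._expand_pools` further down this diff, and `SundialConfig` is one of the two model configs this commit wires in:

import torch.multiprocessing as mp

# Hypothetical: group two pools for the "sundial" model.
group = PoolGroup(model_id="sundial")
for pool_id in range(2):
    queue = mp.Queue()
    pool = InferenceRequestPool(
        pool_id=pool_id,
        model_id="sundial",
        config=SundialConfig(),
        request_queue=queue,
        result_queue=mp.Queue(),
        ready_event=mp.Event(),
    )
    group.add_pool(pool_id, pool, queue)

group.get_pool_ids()       # -> [0, 1]
group.get_request_pool(0)  # the InferenceRequestPool process object
group.get_request_queue(1) # its mp.Queue
# group.get_request_queue(99) would raise InferenceModelInternalError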
iotdb-core/ainode/ainode/core/inference/pool_controller.py

Lines changed: 129 additions & 0 deletions (new file; path recovered from the import in pool_scheduler code below)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from collections import defaultdict
from typing import Dict, Optional

import torch
import torch.multiprocessing as mp

from ainode.core.exception import (
    InferenceModelInternalError,
)
from ainode.core.inference.inference_request import InferenceRequest
from ainode.core.inference.inference_request_pool import InferenceRequestPool, PoolState
from ainode.core.inference.inference_request_pool_group import PoolGroup
from ainode.core.log import Logger

logger = Logger()


class PoolController:
    """
    A controller for handling inference request pools.
    It handles the registration of pools, adding and removing requests,
    and managing the state of each pool.
    """

    DEFAULT_DEVICE = torch.device("cpu")
    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self):
        # structure: {model_id: {pool_id: PoolState}}
        self.pool_states: Dict[str, Dict[int, PoolState]] = defaultdict(dict)
        # structure: {model_id: PoolGroup}
        self._request_pool_map: Dict[str, PoolGroup] = {}

    def dispatch_request(self, model_id, req: InferenceRequest):
        pool_idx = self._select_pool_by_hash(model_id, req.req_id)
        self.add_request(pool_idx, req)
        logger.debug(
            f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_idx}][ID-{req.req_id}] Request is queued for inference"
        )

    def _select_pool_by_hash(self, model_id, req_id) -> int:
        pool_ids = self.get_pool_ids(model_id)
        if not pool_ids:
            raise InferenceModelInternalError(
                f"No available pools for model {model_id}"
            )
        start_idx = hash(req_id) % len(pool_ids)
        n = len(pool_ids)
        for i in range(n):
            pool_id = pool_ids[(start_idx + i) % n]
            state = self.get_state(model_id, pool_id)
            if state == PoolState.RUNNING:
                return pool_id
        raise InferenceModelInternalError(
            f"No RUNNING pools available for model {model_id}"
        )

    def register_pool(self, model_id, pool_id, request_pool, request_queue):
        self.set_state(model_id, pool_id, PoolState.RUNNING)
        self.set_request_pool_map(model_id, pool_id, request_pool, request_queue)

    def add_request(self, pool_id, req):
        req_q = self.get_request_queue(req.model_id, pool_id)
        req_q.put(req)

    def remove_request(self, model_id, req_id):
        pass

    def get_pool_ids(self, model_id) -> list[int]:
        return self._request_pool_map[model_id].get_pool_ids()

    def has_request_pools(self, model_id) -> bool:
        return model_id in self._request_pool_map

    def get_request_pool_map(self) -> Dict[str, PoolGroup]:
        return self._request_pool_map

    def get_request_pools_group(self, model_id) -> Optional[PoolGroup]:
        return self._request_pool_map.get(model_id, None)

    def get_request_pool(self, model_id, pool_id) -> InferenceRequestPool:
        return self._request_pool_map[model_id].get_request_pool(pool_id)

    def get_request_queue(self, model_id, pool_id) -> mp.Queue:
        return self._request_pool_map[model_id].get_request_queue(pool_id)

    def set_request_pool_map(self, model_id, pool_id, request_pool, request_queue):
        if model_id not in self._request_pool_map:
            self._request_pool_map[model_id] = PoolGroup(model_id)
        self._request_pool_map[model_id].add_pool(pool_id, request_pool, request_queue)

    def get_state(self, model_id, pool_id) -> PoolState:
        return self.pool_states[model_id][pool_id]

    def set_state(self, model_id, pool_id, state):
        self.pool_states[model_id][pool_id] = state

    def get_load(self, model_id, pool_id) -> int:
        pass

    def shutdown(self):
        for model_id, pool_group in self._request_pool_map.items():
            for pool_id in pool_group.get_pool_ids():
                request_pool = pool_group.get_request_pool(pool_id)
                request_queue = pool_group.get_request_queue(pool_id)
                request_pool.stop()
                while not request_queue.empty():
                    request_queue.get_nowait()
                request_queue.close()
            for pool_id in pool_group.get_pool_ids():
                request_pool = pool_group.get_request_pool(pool_id)
                request_pool.join(timeout=10)
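The routing policy deserves a closer look: `_select_pool_by_hash` hashes `req_id` to a starting index, then probes the pool ring in order until it finds a `RUNNING` pool, so requests spread across pools while `INITIALIZING` or `STOPPING` pools are skipped. A standalone restatement of that loop (illustrative only; the `running` set stands in for the `get_state(...)` lookup):

def select_pool(pool_ids: list[int], running: set[int], req_id: str) -> int:
    # Hash-and-probe selection, as in PoolController._select_pool_by_hash.
    if not pool_ids:
        raise RuntimeError("no pools")
    n = len(pool_ids)
    start = hash(req_id) % n
    for i in range(n):
        pool_id = pool_ids[(start + i) % n]
        if pool_id in running:  # stands in for PoolState.RUNNING check
            return pool_id
    raise RuntimeError("no RUNNING pools")

# Pool 2 still INITIALIZING: it is skipped wherever the hash lands.
assert select_pool([0, 1, 2, 3], running={0, 1, 3}, req_id="req-42") in {0, 1, 3}

One design note: Python salts `str` hashes per process (PYTHONHASHSEED), so this mapping is stable within a single controller process but not across restarts — adequate for spreading load, not for sticky routing.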
Lines changed: 125 additions & 0 deletions (new file defining PoolScheduler; file path not shown in this capture)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import threading

import torch
import torch.multiprocessing as mp

from ainode.core.exception import (
    InferenceModelInternalError,
)
from ainode.core.inference.inference_request_pool import InferenceRequestPool, PoolState
from ainode.core.inference.pool_controller import PoolController
from ainode.core.log import Logger
from ainode.core.manager.utils import (
    _estimate_pool_size,
)
from ainode.core.model.sundial.configuration_sundial import SundialConfig
from ainode.core.model.timerxl.configuration_timer import TimerConfig
from ainode.core.util.decorator import synchronized

logger = Logger()


class PoolScheduler:
    """
    A Scheduler to init the request pools.
    It initializes the first pool and starts a background thread to expand pools
    as needed based on the model_id.
    """

    DEFAULT_DEVICE = torch.device("cpu")
    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(self, pool_controller: PoolController, result_queue: mp.Queue):
        self._pool_controller = pool_controller
        self._result_queue = result_queue

    @synchronized(threading.Lock())
    def first_req_init(self, model_id: str):
        if not self._pool_controller.has_request_pools(model_id):
            pool_num = _estimate_pool_size(self.DEFAULT_DEVICE, model_id)
            if pool_num <= 0:
                raise InferenceModelInternalError(
                    f"Not enough memory to run model {model_id}."
                )
            # initialize the first pool
            self._first_pool_init(model_id)
            # start a background thread to expand pools
            expand_thread = threading.Thread(
                target=self._expand_pools,
                args=(model_id, 1, pool_num - 1),
                daemon=True,
            )
            expand_thread.start()

    def _first_pool_init(self, model_id: str):
        if model_id == "sundial":
            config = SundialConfig()
        elif model_id == "timer_xl":
            config = TimerConfig()
        first_queue = mp.Queue()
        ready_event = mp.Event()
        first_pool = InferenceRequestPool(
            pool_id=0,
            model_id=model_id,
            config=config,
            request_queue=first_queue,
            result_queue=self._result_queue,
            ready_event=ready_event,
        )
        first_pool.start()
        self._pool_controller.set_state(model_id, 0, PoolState.INITIALIZING)
        if not ready_event.wait(timeout=30):
            logger.error(
                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] First pool failed to be ready in time"
            )
        else:
            self._pool_controller.register_pool(model_id, 0, first_pool, first_queue)
            logger.info(
                f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-0] Initialized inference request pool for model {model_id}"
            )

    def _expand_pools(self, model_id, start_idx, count):
        for idx in range(count):
            queue = mp.Queue()
            pool_id = start_idx + idx
            if model_id == "sundial":
                config = SundialConfig()
            elif model_id == "timer_xl":
                config = TimerConfig()
            pool = InferenceRequestPool(
                pool_id=pool_id,
                model_id=model_id,
                config=config,
                request_queue=queue,
                result_queue=self._result_queue,
                ready_event=mp.Event(),
            )
            pool.start()
            self._pool_controller.set_state(model_id, pool_id, PoolState.INITIALIZING)
            if not pool.ready_event.wait(timeout=30):
                logger.error(
                    f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool_id}] Pool failed to be ready in time"
                )
                continue
            else:
                self._pool_controller.register_pool(model_id, pool_id, pool, queue)
                logger.info(
                    f"[Inference][Device-{self.DEFAULT_DEVICE}][Pool-{pool.pool_id}] New inference request pool started for model {model_id}"
                )
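Taken together, a hypothetical end-to-end wiring of the new classes — the real call site (the request manager this PR decouples from pool management) is not among the diffs shown here:

import torch.multiprocessing as mp

# Hypothetical driver; in AINode this wiring lives in the manager layer.
controller = PoolController()
result_queue = mp.Queue()
scheduler = PoolScheduler(controller, result_queue)

# Lazily bring up pools for a model on first use; "timer_xl" is one of the
# two model IDs this commit special-cases.
scheduler.first_req_init("timer_xl")

# controller.dispatch_request("timer_xl", req)  # req as sketched earlier

# Stop every pool, drain and close its queue, then join the processes.
controller.shutdown()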
