
Commit 2434535

[AINode] Accelerate memory efficiency of the multiprocessing architecture of inference_manager (apache#15956)
1 parent 67f63e0 commit 2434535

File tree

6 files changed: +31 additions, -16 deletions

iotdb-core/ainode/ainode/core/inference/inference_request.py

Lines changed: 1 addition & 2 deletions
@@ -59,9 +59,8 @@ def __init__(
         self.cur_step_idx = 0  # Current write position in the output step index
 
         # Preallocate output buffer [batch_size, max_new_tokens]
-        device = inputs.device
         self.output_tensor = torch.zeros(
-            self.batch_size, max_new_tokens, device=device
+            self.batch_size, max_new_tokens, device="cpu"
         )  # shape: [self.batch_size, max_new_steps]
 
     def mark_running(self):
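
The buffer that used to follow inputs.device is now always allocated in host memory. Since an InferenceRequest travels between processes through multiprocessing queues, keeping its tensors on CPU means queued requests hold no GPU memory and pickle cheaply. A minimal sketch of the idea, using a hypothetical DemoRequest class rather than the project code:

import torch

class DemoRequest:
    """Toy stand-in for InferenceRequest: the result buffer lives on CPU."""

    def __init__(self, batch_size: int, max_new_tokens: int):
        # Host-side buffer: cheap to pickle onto an mp.Queue, and no GPU memory
        # is reserved while the request merely waits in a queue.
        self.output_tensor = torch.zeros(batch_size, max_new_tokens, device="cpu")

req = DemoRequest(batch_size=1, max_new_tokens=96)
assert req.output_tensor.device.type == "cpu"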

iotdb-core/ainode/ainode/core/inference/inference_request_pool.py

Lines changed: 16 additions & 4 deletions
@@ -28,6 +28,7 @@
 from ainode.core.config import AINodeDescriptor
 from ainode.core.inference.inference_request import InferenceRequest
 from ainode.core.log import Logger
+from ainode.core.manager.model_manager import ModelManager
 
 logger = Logger()
 
@@ -45,18 +46,19 @@ class InferenceRequestPool(mp.Process):
     def __init__(
         self,
         pool_id: int,
-        model: PreTrainedModel,
+        model_id: int,
         config: PretrainedConfig,
         request_queue: mp.Queue,
         result_queue: mp.Queue,
         **pool_kwargs,
     ):
         super().__init__()
         self.pool_id = pool_id
-        self.model = model
-        self.device = self.model.device
+        self.model_id = model_id
         self.config = config
         self.pool_kwargs = pool_kwargs
+        self.model = None
+        self.device = None
 
         # TODO: A scheduler is necessary for better handling following queues
         self._threads = []
@@ -97,19 +99,25 @@ def _step(self):
         # TODO: We need a batcher to accelerate the concurrent inference
         # TODO: Check memory size before executing requests
         request: InferenceRequest = self._running_queue.get()
+        inputs = request.inputs.to(self.device)
         output = self.model.generate(
-            request.inputs,
+            inputs,
             max_new_tokens=request.max_new_tokens,
             num_samples=10,
             revin=True,
         )
+        request.output_tensor = request.output_tensor.to(
+            self.device
+        )  # Ensure output tensor is on the same device
         request.write_step_output(output[0].mean(dim=0))
         request.inference_pipeline.post_decode()
         if request.is_finished():
             request.inference_pipeline.post_inference()
             logger.debug(
                 f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
             )
+            # ensure the output tensor is on CPU before sending to result queue
+            request.output_tensor = request.output_tensor.cpu()
             self._finished_queue.put(request)
         else:
             logger.debug(
@@ -123,6 +131,10 @@ def _requests_execute_loop(self):
             self._step()
 
     def run(self):
+        self._model_manager = ModelManager()
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = self._model_manager.load_model(self.model_id, {}).to(self.device)
+
         activate_daemon = threading.Thread(
             target=self._requests_activate_loop, daemon=True
         )
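
The pool is an mp.Process subclass, so everything assigned in __init__ is pickled into the worker when it starts under the "spawn" method. Passing a lightweight model_id instead of the loaded model, and loading it inside run(), keeps the parent process from shipping (or even holding) a full copy of the weights per pool; note the new parameter is annotated model_id: int although the manager passes the string id "sundial". A minimal sketch of this lazy-load pattern, with a hypothetical Worker class and a toy model standing in for ModelManager().load_model(...):

import torch
import torch.multiprocessing as mp


class Worker(mp.Process):
    def __init__(self, model_id: str):
        super().__init__()
        self.model_id = model_id   # small, pickle-friendly handle
        self.model = None          # loaded lazily inside the child process
        self.device = None

    def run(self):
        # Runs in the child: pick the device and load the model there.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = torch.nn.Linear(4, 4).to(self.device)  # stand-in for load_model
        print(f"loaded {self.model_id} on {self.device}")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    w = Worker("sundial")
    w.start()
    w.join()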

iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 7 additions & 9 deletions
@@ -135,17 +135,17 @@ def infer(self, full_data, window_interval=None, window_step=None, **_):
 
 class InferenceManager:
     ACCELERATE_MODEL_ID = "sundial"
-    DEFAULT_DEVICE = "cpu"
-    # DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # DEFAULT_DEVICE = "cpu"
+    DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     DEFAULT_POOL_SIZE = (
         0  # TODO: Remove these parameter by sampling model inference consumption
     )
     WAITING_INTERVAL_IN_MS = (
         AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
     )  # How often to check for requests in the result queue
 
-    def __init__(self, model_manager: ModelManager):
-        self._model_manager = model_manager
+    def __init__(self):
+        self._model_manager = ModelManager()
         self._result_queue = mp.Queue()
         self._result_wrapper_map = {}
         self._result_wrapper_lock = threading.RLock()
@@ -165,14 +165,11 @@ def _init_inference_request_pool(self):
         """
         self._request_pool_map[self.ACCELERATE_MODEL_ID] = []
         for idx in range(self.DEFAULT_POOL_SIZE):
-            sundial_model = self._model_manager.load_model(
-                self.ACCELERATE_MODEL_ID, {}
-            ).to(self.DEFAULT_DEVICE)
             sundial_config = SundialConfig()
             request_queue = mp.Queue()
             request_pool = InferenceRequestPool(
                 pool_id=idx,
-                model=sundial_model,
+                model_id=self.ACCELERATE_MODEL_ID,
                 config=sundial_config,
                 request_queue=request_queue,
                 result_queue=self._result_queue,
@@ -223,7 +220,8 @@ def _run(
         data = full_data[1][0]
         if data.dtype.byteorder not in ("=", "|"):
             data = data.byteswap().newbyteorder()
-        inputs = torch.tensor(data).unsqueeze(0).float().to(self.DEFAULT_DEVICE)
+        # the inputs should be on CPU before passing to the inference request
+        inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
         infer_req = InferenceRequest(
             req_id=_generate_req_id(),
             inputs=inputs,
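
Together with the pool changes above, the manager now follows a simple hand-off convention: tensors placed on a multiprocessing queue stay on CPU, and whoever consumes them moves them to its own device (and moves results back to CPU before returning them). A minimal, self-contained sketch of that convention, with hypothetical producer/consumer roles rather than the project code:

import torch
import torch.multiprocessing as mp


def consumer(queue):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = queue.get().to(device)   # device transfer happens in the worker
    result = (inputs * 2).cpu()       # back to CPU before handing results around
    print(result.shape, result.device)


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    queue = mp.Queue()
    worker = mp.Process(target=consumer, args=(queue,))
    worker.start()
    queue.put(torch.randn(1, 96))     # CPU tensor crosses the process boundary
    worker.join()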

iotdb-core/ainode/ainode/core/manager/model_manager.py

Lines changed: 2 additions & 0 deletions
@@ -29,6 +29,7 @@
 from ainode.core.model.model_info import BuiltInModelType, ModelInfo, ModelStates
 from ainode.core.model.model_storage import ModelStorage
 from ainode.core.rpc.status import get_status
+from ainode.core.util.decorator import singleton
 from ainode.thrift.ainode.ttypes import (
     TDeleteModelReq,
     TRegisterModelReq,
@@ -41,6 +42,7 @@
 logger = Logger()
 
 
+@singleton
 class ModelManager:
     def __init__(self):
         self.model_storage = ModelStorage()
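
The singleton decorator comes from ainode.core.util.decorator, which is not part of this diff; a typical implementation (an assumption, not necessarily the project's) caches the first instance per class:

def singleton(cls):
    """Return a factory that constructs cls once and reuses that instance."""
    _instances = {}

    def get_instance(*args, **kwargs):
        # Construct on first call, then return the cached object afterwards.
        if cls not in _instances:
            _instances[cls] = cls(*args, **kwargs)
        return _instances[cls]

    return get_instance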

iotdb-core/ainode/ainode/core/rpc/handler.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
     def __init__(self, aiNode):
         self._aiNode = aiNode
         self._model_manager = ModelManager()
-        self._inference_manager = InferenceManager(model_manager=self._model_manager)
+        self._inference_manager = InferenceManager()
 
     def stopAINode(self) -> TSStatus:
         self._aiNode.stop()
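
With ModelManager now behaving as a singleton, the handler no longer needs to thread its instance into InferenceManager; both simply call ModelManager() and, assuming a cache-per-class decorator like the sketch above, receive the same object within a process. In the spawned pool processes the manager is constructed fresh rather than pickled from the parent, which is exactly what the lazy loading in run() relies on.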

iotdb-core/ainode/ainode/core/script.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,8 @@
 import shutil
 import sys
 
+import torch.multiprocessing as mp
+
 from ainode.core.ainode import AINode
 from ainode.core.config import AINodeDescriptor
 from ainode.core.constant import TSStatusCode
@@ -86,6 +88,8 @@ def main():
     command = arguments[1]
     if command == "start":
         try:
+            mp.set_start_method("spawn", force=True)
+            logger.info(f"Current multiprocess start method: {mp.get_start_method()}")
             logger.info("IoTDB-AINode is starting...")
             ai_node = AINode()
             ai_node.start()
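
Setting the start method matters because a CUDA context initialized in the parent cannot be safely inherited by fork; "spawn" starts each InferenceRequestPool worker in a fresh interpreter, which is also why the pools reload their model inside run() (see the sketch after the inference_request_pool.py diff above). torch.multiprocessing mirrors the standard multiprocessing API, and force=True overrides any start method that may already have been set, so the call is made once at startup, before any worker processes or queues are created.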
