
Commit f3e2220

[AINode] Adding scheduler to support concurrent inference (#16005)
1 parent 2534344 commit f3e2220

File tree

4 files changed, +247 -42 lines changed

iotdb-core/ainode/ainode/core/inference/inference_request_pool.py

Lines changed: 81 additions & 42 deletions
@@ -16,6 +16,7 @@
 # under the License.
 #
 
+import gc
 import random
 import threading
 import time
@@ -27,6 +28,7 @@
 
 from ainode.core.config import AINodeDescriptor
 from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.scheduler.basic_scheduler import BasicScheduler
 from ainode.core.log import Logger
 from ainode.core.manager.model_manager import ModelManager
 
@@ -61,70 +63,92 @@ def __init__(
         self._model_manager = None
         self.device = None
 
-        # TODO: A scheduler is necessary for better handling following queues
         self._threads = []
         self._waiting_queue = request_queue  # Requests that are waiting to be processed
         self._running_queue = mp.Queue()  # Requests that are currently being processed
         self._finished_queue = result_queue  # Requests that are finished
+        self._scheduler = BasicScheduler(
+            self._waiting_queue, self._running_queue, self._finished_queue, self.pool_id
+        )
         self._stop_event = mp.Event()
 
         # Fix inference seed
         random.seed(self.FIX_SEED)
         torch.manual_seed(self.FIX_SEED)
         np.random.seed(self.FIX_SEED)
 
-    def memory_is_available(self, request):
-        # need test with several rounds of dummy data
-        pass
+    def _warm_up_and_estimate_memory(self):
+        # TODO: Test per token memory usage, add support for cpu in the future
+        torch.cuda.empty_cache()
+        gc.collect()
+        dummy_input = torch.zeros(
+            (1, self.config.input_token_len), dtype=torch.float32
+        ).to(self.device)
+
+        # force cuda synchronization to avoid any asynchronous memory allocation issues
+        torch.cuda.reset_peak_memory_stats(self.device)
+        torch.cuda.synchronize(self.device)
+        memory_before_warmup = torch.cuda.memory_allocated(self.device)
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Before warm-up, peak memory usage: {memory_before_warmup:.2f} bytes"
+        )
 
-    def _activate_requests(self):
-        if self._waiting_queue.empty():
-            return
-        request: InferenceRequest = self._waiting_queue.get()
-        # TODO: Check memory size before activating requests
-        request.inputs = request.inference_pipeline.preprocess_inputs(request.inputs)
-        request.mark_running()
-        logger.debug(
-            f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+        # warm-up
+        with torch.no_grad():
+            self.model.generate(dummy_input, max_new_tokens=1)
+        torch.cuda.synchronize(self.device)
+        peak_memory_1_token = torch.cuda.max_memory_allocated(self.device)
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Baseline memory usage for 1 token: {peak_memory_1_token:.2f} bytes"
+        )
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Differentiation : {peak_memory_1_token-memory_before_warmup:.2f} bytes"
         )
-        self._running_queue.put(request)
+
+    def _activate_requests(self):
+        requests = self._scheduler.schedule_activate()
+        for request in requests:
+            request.inputs = request.inference_pipeline.preprocess_inputs(
+                request.inputs
+            )
+            request.mark_running()
+            self._running_queue.put(request)
+            logger.debug(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+            )

92120
def _requests_activate_loop(self):
93121
while not self._stop_event.is_set():
94122
time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
95123
self._activate_requests()
96124

97125
def _step(self):
98-
if self._running_queue.empty():
99-
return
126+
requests = self._scheduler.schedule_step()
100127
# TODO: We need a batcher to accelerate the concurrent inference
101-
# TODO: Check memory size before executing requests
102-
request: InferenceRequest = self._running_queue.get()
103-
inputs = request.inputs.to(self.device)
104-
output = self.model.generate(
105-
inputs,
106-
max_new_tokens=request.max_new_tokens,
107-
num_samples=10,
108-
revin=True,
109-
)
110-
request.output_tensor = request.output_tensor.to(
111-
self.device
112-
) # Ensure output tensor is on the same device
113-
request.write_step_output(output[0].mean(dim=0))
114-
request.inference_pipeline.post_decode()
115-
if request.is_finished():
116-
request.inference_pipeline.post_inference()
117-
logger.debug(
118-
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
119-
)
120-
# ensure the output tensor is on CPU before sending to result queue
121-
request.output_tensor = request.output_tensor.cpu()
122-
self._finished_queue.put(request)
123-
else:
124-
logger.debug(
125-
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
128+
for request in requests:
129+
request.inputs = request.inputs.to(self.device)
130+
output = self.model.generate(
131+
request.inputs,
132+
max_new_tokens=request.max_new_tokens,
133+
num_samples=10,
134+
revin=True,
126135
)
127-
self._waiting_queue.put(request)
136+
request.output_tensor = request.output_tensor.to(self.device)
137+
request.write_step_output(output[0].mean(dim=0))
138+
request.inference_pipeline.post_decode()
139+
if request.is_finished():
140+
request.inference_pipeline.post_inference()
141+
logger.debug(
142+
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
143+
)
144+
# ensure the output tensor is on CPU before sending to result queue
145+
request.output_tensor = request.output_tensor.cpu()
146+
self._finished_queue.put(request)
147+
else:
148+
logger.debug(
149+
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
150+
)
151+
self._waiting_queue.put(request)
128152

129153
def _requests_execute_loop(self):
130154
while not self._stop_event.is_set():
@@ -134,8 +158,11 @@ def _requests_execute_loop(self):
     def run(self):
         self._model_manager = ModelManager()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self._scheduler.device = self.device
         self.model = self._model_manager.load_model(self.model_id, {}).to(self.device)
 
+        # self._warm_up_and_estimate_memory()
+
         activate_daemon = threading.Thread(
             target=self._requests_activate_loop, daemon=True
         )
@@ -151,3 +178,15 @@ def run(self):
 
     def stop(self):
         self._stop_event.set()
+        logger.info(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Stopping and releasing resources."
+        )
+        try:
+            del self.model
+            if "cuda" in str(self.device):
+                torch.cuda.empty_cache()
+            gc.collect()
+        except Exception as e:
+            logger.warning(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}] Failed to clean up: {e}"
+            )
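With this change the pool no longer polls its queues directly: _activate_requests and _step each ask the scheduler for a batch and iterate over it, and unfinished requests are pushed back to the waiting queue. Below is a minimal, self-contained sketch of that activate/step hand-off. It is not the AINode code: _DummyRequest and _FifoScheduler are hypothetical stand-ins for InferenceRequest and BasicScheduler, plain queue.Queue replaces mp.Queue, and the memory check is stubbed out.

import queue
import threading
import time


class _DummyRequest:
    # Hypothetical stand-in for InferenceRequest: finishes after `steps` inference steps.
    def __init__(self, req_id, steps=2):
        self.req_id = req_id
        self.steps_left = steps

    def step(self):
        self.steps_left -= 1

    def is_finished(self):
        return self.steps_left <= 0


class _FifoScheduler:
    # Mirrors the shape of BasicScheduler.schedule_activate/schedule_step,
    # with the memory check stubbed out to "always available".
    def __init__(self, waiting_queue, running_queue, max_batch=10):
        self.waiting_queue = waiting_queue
        self.running_queue = running_queue
        self.max_batch = max_batch

    def schedule_activate(self):
        batch = []
        while not self.waiting_queue.empty() and len(batch) < self.max_batch:
            batch.append(self.waiting_queue.get())
        return batch

    def schedule_step(self):
        batch = []
        while not self.running_queue.empty() and len(batch) < self.max_batch:
            batch.append(self.running_queue.get())
        return batch


def run_pool(num_requests=5):
    waiting, running, finished = queue.Queue(), queue.Queue(), queue.Queue()
    scheduler = _FifoScheduler(waiting, running)
    for i in range(num_requests):
        waiting.put(_DummyRequest(i))

    def activate_loop():
        while finished.qsize() < num_requests:
            for request in scheduler.schedule_activate():
                running.put(request)  # "mark_running" in the real pool
            time.sleep(0.01)

    def step_loop():
        while finished.qsize() < num_requests:
            for request in scheduler.schedule_step():
                request.step()  # model.generate(...) in the real pool
                (finished if request.is_finished() else waiting).put(request)
            time.sleep(0.01)

    threads = [threading.Thread(target=activate_loop, daemon=True),
               threading.Thread(target=step_loop, daemon=True)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=5)
    print(f"finished {finished.qsize()} of {num_requests} requests")


if __name__ == "__main__":
    run_pool()

The real pool runs these two loops as daemon threads and stops them through an mp.Event; the queue hand-off pattern is the same.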
iotdb-core/ainode/ainode/core/inference/scheduler/__init__.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
iotdb-core/ainode/ainode/core/inference/scheduler/abstract_scheduler.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC, abstractmethod
+
+
+class AbstractScheduler(ABC):
+    """
+    Abstract base class for inference scheduling strategies.
+
+    This class defines the high-level interface for scheduling inference requests.
+    A scheduler is responsible for managing the execution order of inference tasks across different
+    stages: waiting, running, and finished.
+
+    Subclasses should implement specific scheduling logic.
+    """
+
+    def __init__(self, waiting_queue, running_queue, finished_queue):
+        """
+        Args:
+            waiting_queue: Queue containing inference requests that are waiting to be executed.
+            running_queue: Queue containing currently running inference tasks.
+            finished_queue: Queue containing completed inference tasks.
+        """
+        self.waiting_queue = waiting_queue
+        self.running_queue = running_queue
+        self.finished_queue = finished_queue
+
+    @abstractmethod
+    def schedule_activate(self) -> list:
+        """
+        Select one or more inference requests from the waiting queue that are ready to be activated and processed.
+
+        Returns:
+            List: A list of inference requests that will be moved to the running queue.
+        """
+        pass
+
+    @abstractmethod
+    def schedule_step(self) -> list:
+        """
+        Select one or more inference requests from the running queue that are ready to perform the next inference step.
+
+        Returns:
+            List: A list of inference requests that are scheduled to run an inference step.
+        """
+        pass
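AbstractScheduler fixes only the two hooks the request pool calls, schedule_activate and schedule_step; the queues are shared state and the policy is left to the subclass. As a hedged illustration (not part of this commit), a priority-based policy could look like the sketch below. It assumes the AINode sources are importable and that requests carry a hypothetical numeric priority attribute.

from ainode.core.inference.scheduler.abstract_scheduler import AbstractScheduler


class PriorityScheduler(AbstractScheduler):
    """Drain the waiting queue and activate the highest-priority requests first."""

    def __init__(self, waiting_queue, running_queue, finished_queue, max_batch=10):
        super().__init__(waiting_queue, running_queue, finished_queue)
        self.max_batch = max_batch

    def schedule_activate(self) -> list:
        pending = []
        while not self.waiting_queue.empty():
            pending.append(self.waiting_queue.get())
        # Highest priority first; requests beyond the batch limit go back to waiting.
        pending.sort(key=lambda request: getattr(request, "priority", 0), reverse=True)
        batch, overflow = pending[: self.max_batch], pending[self.max_batch :]
        for request in overflow:
            self.waiting_queue.put(request)
        return batch

    def schedule_step(self) -> list:
        batch = []
        while not self.running_queue.empty() and len(batch) < self.max_batch:
            batch.append(self.running_queue.get())
        return batch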
iotdb-core/ainode/ainode/core/inference/scheduler/basic_scheduler.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+
+import psutil
+import torch
+
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.inference.scheduler.abstract_scheduler import AbstractScheduler
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class BasicScheduler(AbstractScheduler):
+    """
+    A simple FIFO scheduler that selects requests based on memory availability and activation/step size.
+    """
+
+    def __init__(
+        self,
+        waiting_queue,
+        running_queue,
+        finished_queue,
+        pool_id,
+        max_memory_bytes=1 << 30,
+        max_activate_size=10,
+        max_step_size=10,
+    ):
+        super().__init__(waiting_queue, running_queue, finished_queue)
+        self.max_memory_bytes = max_memory_bytes
+        self.max_activate_size = max_activate_size
+        self.max_step_size = max_step_size
+        self.pool_id = pool_id
+        self.device = None
+
+    def memory_is_available(self):
+        if "cuda" in self.device.type:
+            used = torch.cuda.memory_allocated(self.device)
+            reserved = torch.cuda.memory_reserved(self.device)
+        elif "cpu" in self.device.type:
+            process = psutil.Process(os.getpid())
+            used = process.memory_info().rss
+            reserved = used
+        else:
+            used = 0
+            reserved = 0
+            logger.warning(
+                f"[Inference] Unsupported device type: {self.device.type}. Memory checks will not be performed."
+            )
+        logger.debug(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}] "
+            f"Memory used: {used} bytes, Max memory: {self.max_memory_bytes} bytes"
+        )
+        return used < self.max_memory_bytes
+
+    def schedule_activate(self) -> list:
+        requests = []
+        while not self.waiting_queue.empty() and len(requests) < self.max_activate_size:
+            if not self.memory_is_available():
+                break
+            requests.append(self.waiting_queue.get())
+        return requests
+
+    def schedule_step(self) -> list:
+        requests = []
+        while not self.running_queue.empty() and len(requests) < self.max_step_size:
+            if not self.memory_is_available():
+                break
+            requests.append(self.running_queue.get())
+        return requests
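A hedged usage sketch (not part of this commit) of BasicScheduler outside the pool: plain queue.Queue objects and strings stand in for the pool's mp.Queue and InferenceRequest, since the scheduler only calls empty() and get() on its queues. It assumes the AINode sources, torch, and psutil are importable.

import queue

import torch

from ainode.core.inference.scheduler.basic_scheduler import BasicScheduler

if __name__ == "__main__":
    waiting, running, finished = queue.Queue(), queue.Queue(), queue.Queue()
    scheduler = BasicScheduler(
        waiting,
        running,
        finished,
        pool_id=0,
        max_memory_bytes=8 << 30,  # generous cap so the CPU RSS check passes here
        max_activate_size=4,
    )
    # The pool assigns the device in run() (self._scheduler.device = self.device);
    # it must be set before scheduling because memory_is_available() reads device.type.
    scheduler.device = torch.device("cpu")

    for i in range(6):
        waiting.put(f"request-{i}")

    first_batch = scheduler.schedule_activate()   # up to max_activate_size requests
    second_batch = scheduler.schedule_activate()  # the remaining two
    print(len(first_batch), len(second_batch))    # 4 2

Note that on CPU memory_is_available() compares the whole process RSS against max_memory_bytes, so the default 1 GiB cap can already be exceeded just by importing torch and loading a model; the sketch passes a larger cap to keep the example deterministic.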
