Commit 393aba2

[AINode] Support concurrent inference for Timer-Sundial (#15897)
1 parent 772bab4 commit 393aba2

File tree: 12 files changed, +624 -21 lines

iotdb-core/ainode/ainode/core/config.py

Lines changed: 17 additions & 0 deletions

@@ -30,6 +30,7 @@
     AINODE_CONF_FILE_NAME,
     AINODE_CONF_GIT_FILE_NAME,
     AINODE_CONF_POM_FILE_NAME,
+    AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
     AINODE_INFERENCE_RPC_ADDRESS,
     AINODE_INFERENCE_RPC_PORT,
     AINODE_LOG_DIR,
@@ -55,6 +56,9 @@ def __init__(self):
         # Used for connection of DataNode/ConfigNode clients
         self._ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS
         self._ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT
+        self._ain_inference_batch_interval_in_ms: int = (
+            AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
+        )

         # log directory
         self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -132,6 +136,14 @@ def get_ain_inference_rpc_port(self) -> int:
     def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None:
         self._ain_inference_rpc_port = ain_inference_rpc_port

+    def get_ain_inference_batch_interval_in_ms(self) -> int:
+        return self._ain_inference_batch_interval_in_ms
+
+    def set_ain_inference_batch_interval_in_ms(
+        self, ain_inference_batch_interval_in_ms: int
+    ) -> None:
+        self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
+
     def get_ain_logs_dir(self) -> str:
         return self._ain_logs_dir

@@ -273,6 +285,11 @@ def _load_config_from_file(self) -> None:
                 int(file_configs["ain_inference_rpc_port"])
             )

+        if "ain_inference_batch_interval_in_ms" in config_keys:
+            self._config.set_ain_inference_batch_interval_in_ms(
+                int(file_configs["ain_inference_batch_interval_in_ms"])
+            )
+
         if "ain_models_dir" in config_keys:
             self._config.set_ain_models_dir(file_configs["ain_models_dir"])
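The new option flows from the configuration file into the config object behind AINodeDescriptor and is exposed through the getter/setter pair above. A minimal sketch of setting and reading it, using only calls that appear in this commit; the properties key comes from _load_config_from_file, while the exact name of the AINode properties file is not shown in this diff:

# Hypothetical entry in the AINode properties file (key as parsed above):
#   ain_inference_batch_interval_in_ms=15

from ainode.core.config import AINodeDescriptor

# Read the option back the same way InferenceRequestPool does further down.
interval_ms = (
    AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
)
print(f"Inference queues are polled every {interval_ms} ms")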

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 5 additions & 0 deletions

@@ -29,22 +29,27 @@
 AINODE_CONF_GIT_FILE_NAME = "git.properties"
 AINODE_CONF_POM_FILE_NAME = "pom.properties"
 AINODE_SYSTEM_FILE_NAME = "system.properties"
+
 # inference_rpc_address
 AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
 AINODE_INFERENCE_RPC_PORT = 10810
+AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
+
 # AINode folder structure
 AINODE_MODELS_DIR = "data/ainode/models"
 AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights"  # For built-in models, we only need to store their weights and config.
 AINODE_SYSTEM_DIR = "data/ainode/system"
 AINODE_LOG_DIR = "logs/ainode"
 AINODE_THRIFT_COMPRESSION_ENABLED = False
+
 # use for node management
 AINODE_CLUSTER_NAME = "defaultCluster"
 AINODE_VERSION_INFO = "UNKNOWN"
 AINODE_BUILD_INFO = "UNKNOWN"
 AINODE_ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
 )
+
 # connect IoTDB cluster
 AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
 AINODE_CLUSTER_INGRESS_PORT = 6667
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+from typing import Any
+
+import torch
+
+from ainode.core.inference.strategy.abstract_inference_pipeline import (
+    AbstractInferencePipeline,
+)
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class InferenceRequestState:
+    WAITING = "waiting"
+    RUNNING = "running"
+    FINISHED = "finished"
+
+
+class InferenceRequest:
+    def __init__(
+        self,
+        req_id: str,
+        inputs: torch.Tensor,
+        inference_pipeline: AbstractInferencePipeline,
+        max_new_tokens: int = 96,
+        **infer_kwargs,
+    ):
+        if inputs.ndim == 1:
+            inputs = inputs.unsqueeze(0)
+
+        self.req_id = req_id
+        self.inputs = inputs
+        self.infer_kwargs = infer_kwargs
+        self.inference_pipeline = inference_pipeline
+        self.max_new_tokens = (
+            max_new_tokens  # Number of time series data points to generate
+        )
+
+        self.batch_size = inputs.size(0)
+        self.state = InferenceRequestState.WAITING
+        self.cur_step_idx = 0  # Current write position in the output step index
+
+        # Preallocate output buffer [batch_size, max_new_tokens]
+        device = inputs.device
+        self.output_tensor = torch.zeros(
+            self.batch_size, max_new_tokens, device=device
+        )  # shape: [self.batch_size, max_new_steps]
+
+    def mark_running(self):
+        self.state = InferenceRequestState.RUNNING
+
+    def mark_finished(self):
+        self.state = InferenceRequestState.FINISHED
+
+    def is_finished(self) -> bool:
+        return (
+            self.state == InferenceRequestState.FINISHED
+            or self.cur_step_idx >= self.max_new_tokens
+        )
+
+    def write_step_output(self, step_output: torch.Tensor):
+        if step_output.ndim == 1:
+            step_output = step_output.unsqueeze(0)
+
+        batch_size, step_size = step_output.shape
+        end_idx = self.cur_step_idx + step_size
+
+        if end_idx > self.max_new_tokens:
+            self.output_tensor[:, self.cur_step_idx :] = step_output[
+                :, : self.max_new_tokens - self.cur_step_idx
+            ]
+            self.cur_step_idx = self.max_new_tokens
+        else:
+            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.cur_step_idx = end_idx
+
+        if self.is_finished():
+            self.mark_finished()
+
+    def get_final_output(self) -> torch.Tensor:
+        return self.output_tensor[:, : self.cur_step_idx]
+
+
+class InferenceRequestProxy:
+    """
+    Wrap the raw request for handling multiprocess processing.
+    """
+
+    def __init__(self, req_id: str):
+        self.req_id = req_id
+        self.result = None
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+
+    def set_result(self, result: Any):
+        with self._lock:
+            self.result = result
+            self._condition.notify_all()
+
+    def wait_for_completion(self) -> Any:
+        with self._lock:
+            self._condition.wait()
+        return self.result
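Taken together, InferenceRequest carries a request's tensors and progress inside the pool process, while InferenceRequestProxy is the handle a submitting thread blocks on until the result comes back. A minimal usage sketch built only from the methods above; my_pipeline stands in for an AbstractInferencePipeline implementation that is not part of this file:

import torch

# A single 96-point univariate series; __init__ promotes 1-D input to shape [1, 96].
req = InferenceRequest(
    req_id="demo-1",
    inputs=torch.randn(96),
    inference_pipeline=my_pipeline,  # placeholder, defined elsewhere
    max_new_tokens=96,
)

# The pool fills the preallocated output buffer step by step.
req.write_step_output(torch.randn(1, 48))  # first 48 forecast points
req.write_step_output(torch.randn(1, 48))  # remaining 48; the request finishes
assert req.is_finished()
forecast = req.get_final_output()          # shape [1, 96]

On the submitting side, InferenceRequestProxy.wait_for_completion() blocks on its condition variable until another thread calls set_result(), so the proxy is meant to be handed to whichever thread drains the pool's result queue.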
Lines changed: 140 additions & 0 deletions

@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import random
+import threading
+import time
+
+import numpy as np
+import torch
+import torch.multiprocessing as mp
+from transformers import PretrainedConfig, PreTrainedModel
+
+from ainode.core.config import AINodeDescriptor
+from ainode.core.inference.inference_request import InferenceRequest
+from ainode.core.log import Logger
+
+logger = Logger()
+
+
+class InferenceRequestPool(mp.Process):
+    """
+    The request pool to handle inference for a specific model.
+    """
+
+    FIX_SEED = 2021
+    WAITING_INTERVAL_IN_MS = (
+        AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
+    )  # How often to check for requests in the waiting/running queue
+
+    def __init__(
+        self,
+        pool_id: int,
+        model: PreTrainedModel,
+        config: PretrainedConfig,
+        request_queue: mp.Queue,
+        result_queue: mp.Queue,
+        **pool_kwargs,
+    ):
+        super().__init__()
+        self.pool_id = pool_id
+        self.model = model
+        self.device = self.model.device
+        self.config = config
+        self.pool_kwargs = pool_kwargs
+
+        # TODO: A scheduler is necessary for better handling following queues
+        self._threads = []
+        self._waiting_queue = request_queue  # Requests that are waiting to be processed
+        self._running_queue = mp.Queue()  # Requests that are currently being processed
+        self._finished_queue = result_queue  # Requests that are finished
+        self._stop_event = mp.Event()
+
+        # Fix inference seed
+        random.seed(self.FIX_SEED)
+        torch.manual_seed(self.FIX_SEED)
+        np.random.seed(self.FIX_SEED)
+
+    def memory_is_available(self, request):
+        # need test with several rounds of dummy data
+        pass
+
+    def _activate_requests(self):
+        if self._waiting_queue.empty():
+            return
+        request: InferenceRequest = self._waiting_queue.get()
+        # TODO: Check memory size before activating requests
+        request.inputs = request.inference_pipeline.preprocess_inputs(request.inputs)
+        request.mark_running()
+        logger.debug(
+            f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is activated with inputs shape {request.inputs.shape}"
+        )
+        self._running_queue.put(request)
+
+    def _requests_activate_loop(self):
+        while not self._stop_event.is_set():
+            time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+            self._activate_requests()
+
+    def _step(self):
+        if self._running_queue.empty():
+            return
+        # TODO: We need a batcher to accelerate the concurrent inference
+        # TODO: Check memory size before executing requests
+        request: InferenceRequest = self._running_queue.get()
+        output = self.model.generate(
+            request.inputs,
+            max_new_tokens=request.max_new_tokens,
+            num_samples=10,
+            revin=True,
+        )
+        request.write_step_output(output[0].mean(dim=0))
+        request.inference_pipeline.post_decode()
+        if request.is_finished():
+            request.inference_pipeline.post_inference()
+            logger.debug(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+            )
+            self._finished_queue.put(request)
+        else:
+            logger.debug(
+                f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+            )
+            self._waiting_queue.put(request)
+
+    def _requests_execute_loop(self):
+        while not self._stop_event.is_set():
+            time.sleep(self.WAITING_INTERVAL_IN_MS / 1000)
+            self._step()
+
+    def run(self):
+        activate_daemon = threading.Thread(
+            target=self._requests_activate_loop, daemon=True
+        )
+        self._threads.append(activate_daemon)
+        activate_daemon.start()
+        execute_daemon = threading.Thread(
+            target=self._requests_execute_loop, daemon=True
+        )
+        self._threads.append(execute_daemon)
+        execute_daemon.start()
+        for thread in self._threads:
+            thread.join()
+
+    def stop(self):
+        self._stop_event.set()
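The pool is a torch.multiprocessing process that communicates with the rest of AINode only through the two queues it is given, so the dispatching side (other files in this commit, not reproduced here) owns those queues. A hedged sketch of the wiring, reusing the class defined above; load_sundial_model() and my_request are placeholders rather than APIs from this diff:

import torch.multiprocessing as mp

request_queue, result_queue = mp.Queue(), mp.Queue()
model, config = load_sundial_model()  # placeholder for however Timer-Sundial is loaded

pool = InferenceRequestPool(
    pool_id=0,
    model=model,
    config=config,
    request_queue=request_queue,
    result_queue=result_queue,
)
pool.start()  # mp.Process.start() calls run(), which spawns the activate/execute threads

request_queue.put(my_request)    # an InferenceRequest built as sketched earlier
finished = result_queue.get()    # comes back once the pool marks it finished
forecast = finished.get_final_output()

pool.stop()  # sets the shared stop event so both polling loops exit

Note that WAITING_INTERVAL_IN_MS is resolved from AINodeDescriptor when the class body is evaluated, i.e. at import time, so ain_inference_batch_interval_in_ms must already be loaded by then; with the default of 15 ms, each queue is polled roughly 67 times per second.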
Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
