Skip to content

Commit 38c7985

Browse files
committed
seems finished and accelerated
1 parent af5e712 commit 38c7985

File tree: 6 files changed (+171, −227 lines)

iotdb-core/ainode/ainode/core/config.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
AINODE_CONF_FILE_NAME,
3131
AINODE_CONF_GIT_FILE_NAME,
3232
AINODE_CONF_POM_FILE_NAME,
33+
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
3334
AINODE_INFERENCE_RPC_ADDRESS,
3435
AINODE_INFERENCE_RPC_PORT,
3536
AINODE_LOG_DIR,
@@ -55,6 +56,9 @@ def __init__(self):
5556
# Used for connection of DataNode/ConfigNode clients
5657
self._ain_inference_rpc_address: str = AINODE_INFERENCE_RPC_ADDRESS
5758
self._ain_inference_rpc_port: int = AINODE_INFERENCE_RPC_PORT
59+
self._ain_inference_batch_interval_in_ms: int = (
60+
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
61+
)
5862

5963
# log directory
6064
self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -132,6 +136,14 @@ def get_ain_inference_rpc_port(self) -> int:
132136
def set_ain_inference_rpc_port(self, ain_inference_rpc_port: int) -> None:
133137
self._ain_inference_rpc_port = ain_inference_rpc_port
134138

139+
def get_ain_inference_batch_interval_in_ms(self) -> int:
140+
return self._ain_inference_batch_interval_in_ms
141+
142+
def set_ain_inference_batch_interval_in_ms(
143+
self, ain_inference_batch_interval_in_ms: int
144+
) -> None:
145+
self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
146+
135147
def get_ain_logs_dir(self) -> str:
136148
return self._ain_logs_dir
137149

@@ -273,6 +285,11 @@ def _load_config_from_file(self) -> None:
273285
int(file_configs["ain_inference_rpc_port"])
274286
)
275287

288+
if "ain_inference_batch_interval_in_ms" in config_keys:
289+
self._config.set_ain_inference_batch_interval_in_ms(
290+
int(file_configs["ain_inference_batch_interval_in_ms"])
291+
)
292+
276293
if "ain_models_dir" in config_keys:
277294
self._config.set_ain_models_dir(file_configs["ain_models_dir"])
278295

iotdb-core/ainode/ainode/core/constant.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,30 @@
2929
AINODE_CONF_GIT_FILE_NAME = "git.properties"
3030
AINODE_CONF_POM_FILE_NAME = "pom.properties"
3131
AINODE_SYSTEM_FILE_NAME = "system.properties"
32+
3233
# inference_rpc_address
3334
AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
34-
AINODE_INFERENCE_RPC_PORT = 11810
35+
AINODE_INFERENCE_RPC_PORT = 10810
36+
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
37+
3538
# AINode folder structure
3639
AINODE_MODELS_DIR = "data/ainode/models"
3740
AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights" # For built-in models, we only need to store their weights and config.
3841
AINODE_SYSTEM_DIR = "data/ainode/system"
3942
AINODE_LOG_DIR = "logs/ainode"
4043
AINODE_THRIFT_COMPRESSION_ENABLED = False
44+
4145
# use for node management
42-
AINODE_CLUSTER_NAME = "yongzaoCluster"
46+
AINODE_CLUSTER_NAME = "defaultCluster"
4347
AINODE_VERSION_INFO = "UNKNOWN"
4448
AINODE_BUILD_INFO = "UNKNOWN"
4549
AINODE_ROOT_DIR = os.path.dirname(
4650
os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
4751
)
52+
4853
# connect IoTDB cluster
4954
AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
50-
AINODE_CLUSTER_INGRESS_PORT = 7667
55+
AINODE_CLUSTER_INGRESS_PORT = 6667
5156
AINODE_CLUSTER_INGRESS_USERNAME = "root"
5257
AINODE_CLUSTER_INGRESS_PASSWORD = "root"
5358
AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"

iotdb-core/ainode/ainode/core/inference/inference_request.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import torch
2222

2323
from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy
24+
from ainode.core.log import Logger
25+
26+
logger = Logger()
2427

2528

2629
class InferenceRequestState:
@@ -32,7 +35,7 @@ class InferenceRequestState:
3235
class InferenceRequest:
3336
def __init__(
3437
self,
35-
req_id: int,
38+
req_id: str,
3639
inputs: torch.Tensor,
3740
strategy: AbstractStrategy,
3841
max_new_tokens: int = 96,
@@ -41,7 +44,7 @@ def __init__(
4144
if inputs.ndim == 1:
4245
inputs = inputs.unsqueeze(0)
4346

44-
self.id = req_id
47+
self.req_id = req_id
4548
self.inputs = inputs
4649
self.infer_kwargs = infer_kwargs
4750
self.strategy = strategy
@@ -59,9 +62,6 @@ def __init__(
5962
self.batch_size, max_new_tokens, device=device
6063
) # shape: [self.batch_size, max_new_steps]
6164

62-
self._lock = threading.Lock()
63-
self._condition = threading.Condition(self._lock)
64-
6565
def mark_running(self):
6666
self.state = InferenceRequestState.RUNNING
6767

@@ -75,34 +75,45 @@ def is_finished(self) -> bool:
7575
)
7676

7777
def write_step_output(self, step_output: torch.Tensor):
78-
with self._lock:
79-
if step_output.ndim == 1:
80-
step_output = step_output.unsqueeze(0)
78+
if step_output.ndim == 1:
79+
step_output = step_output.unsqueeze(0)
8180

82-
batch_size, step_size = step_output.shape
83-
end_idx = self.cur_step_idx + step_size
81+
batch_size, step_size = step_output.shape
82+
end_idx = self.cur_step_idx + step_size
8483

85-
if end_idx > self.max_new_tokens:
86-
self.output_tensor[:, self.cur_step_idx :] = step_output[
87-
:, : self.max_new_tokens - self.cur_step_idx
88-
]
89-
self.cur_step_idx = self.max_new_tokens
90-
else:
91-
self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
92-
self.cur_step_idx = end_idx
84+
if end_idx > self.max_new_tokens:
85+
self.output_tensor[:, self.cur_step_idx :] = step_output[
86+
:, : self.max_new_tokens - self.cur_step_idx
87+
]
88+
self.cur_step_idx = self.max_new_tokens
89+
else:
90+
self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
91+
self.cur_step_idx = end_idx
9392

94-
if self.is_finished():
95-
self.mark_finished()
93+
if self.is_finished():
94+
self.mark_finished()
9695

9796
def get_final_output(self) -> torch.Tensor:
98-
with self._lock:
99-
return self.output_tensor[:, : self.cur_step_idx]
97+
return self.output_tensor[:, : self.cur_step_idx]
98+
99+
100+
class InferenceRequestProxy:
101+
"""
102+
Wrap the raw request for handling multiprocess processing.
103+
"""
104+
105+
def __init__(self, req_id: str):
106+
self.req_id = req_id
107+
self.result = None
108+
self._lock = threading.Lock()
109+
self._condition = threading.Condition(self._lock)
100110

101-
def notify_completion(self):
111+
def set_result(self, result: Any):
102112
with self._lock:
113+
self.result = result
103114
self._condition.notify_all()
104115

105116
def wait_for_completion(self) -> Any:
106117
with self._lock:
107-
while self.state != InferenceRequestState.FINISHED:
108-
self._condition.wait()
118+
self._condition.wait()
119+
return self.result

Commit comments: 0