
Commit b1b90cd
stash changes
1 parent 9042a0d

File tree

11 files changed: +674 -21 lines


iotdb-core/ainode/ainode/core/constant.py

Lines changed: 3 additions & 3 deletions
@@ -31,23 +31,23 @@
 AINODE_SYSTEM_FILE_NAME = "system.properties"
 # inference_rpc_address
 AINODE_INFERENCE_RPC_ADDRESS = "127.0.0.1"
-AINODE_INFERENCE_RPC_PORT = 10810
+AINODE_INFERENCE_RPC_PORT = 11810
 # AINode folder structure
 AINODE_MODELS_DIR = "data/ainode/models"
 AINODE_BUILTIN_MODELS_DIR = "data/ainode/models/weights"  # For built-in models, we only need to store their weights and config.
 AINODE_SYSTEM_DIR = "data/ainode/system"
 AINODE_LOG_DIR = "logs/ainode"
 AINODE_THRIFT_COMPRESSION_ENABLED = False
 # use for node management
-AINODE_CLUSTER_NAME = "defaultCluster"
+AINODE_CLUSTER_NAME = "yongzaoCluster"
 AINODE_VERSION_INFO = "UNKNOWN"
 AINODE_BUILD_INFO = "UNKNOWN"
 AINODE_ROOT_DIR = os.path.dirname(
     os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
 )
 # connect IoTDB cluster
 AINODE_CLUSTER_INGRESS_ADDRESS = "127.0.0.1"
-AINODE_CLUSTER_INGRESS_PORT = 6667
+AINODE_CLUSTER_INGRESS_PORT = 7667
 AINODE_CLUSTER_INGRESS_USERNAME = "root"
 AINODE_CLUSTER_INGRESS_PASSWORD = "root"
 AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
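
The ingress constants above are the defaults an AINode uses to reach the IoTDB cluster. As a rough illustration (not part of this commit), a client connection built from the updated defaults might look like the following, assuming the standard iotdb.Session API from the IoTDB Python client:

from iotdb.Session import Session

# Values mirror the constants after this change; this is an assumed usage sketch,
# not the actual AINode client code.
session = Session(
    "127.0.0.1",  # AINODE_CLUSTER_INGRESS_ADDRESS
    7667,         # AINODE_CLUSTER_INGRESS_PORT (was 6667)
    "root",       # AINODE_CLUSTER_INGRESS_USERNAME
    "root",       # AINODE_CLUSTER_INGRESS_PASSWORD
)
session.open(False)
# ... issue reads/writes against the cluster ...
session.close()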
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import threading
+from typing import Any
+
+import torch
+
+from ainode.core.inference.strategy.abstract_strategy import AbstractStrategy
+
+
+class InferenceRequestState:
+    WAITING = "waiting"
+    RUNNING = "running"
+    FINISHED = "finished"
+
+
+class InferenceRequest:
+    def __init__(
+        self,
+        req_id: int,
+        inputs: torch.Tensor,
+        strategy: AbstractStrategy,
+        max_new_tokens: int = 96,
+        **infer_kwargs,
+    ):
+        if inputs.ndim == 1:
+            inputs = inputs.unsqueeze(0)
+
+        self.id = req_id
+        self.inputs = inputs
+        self.infer_kwargs = infer_kwargs
+        self.strategy = strategy
+        self.max_new_tokens = (
+            max_new_tokens  # Number of time series data points to generate
+        )
+
+        self.batch_size = inputs.size(0)
+        self.state = InferenceRequestState.WAITING
+        self.cur_step_idx = 0  # Current write position in the output step index
+
+        # Preallocate output buffer [batch_size, max_new_tokens]
+        device = inputs.device
+        self.output_tensor = torch.zeros(
+            self.batch_size, max_new_tokens, device=device
+        )  # shape: [self.batch_size, max_new_steps]
+
+        self._lock = threading.Lock()
+        self._condition = threading.Condition(self._lock)
+
+    def mark_running(self):
+        self.state = InferenceRequestState.RUNNING
+
+    def mark_finished(self):
+        self.state = InferenceRequestState.FINISHED
+
+    def is_finished(self) -> bool:
+        return (
+            self.state == InferenceRequestState.FINISHED
+            or self.cur_step_idx >= self.max_new_tokens
+        )
+
+    def write_step_output(self, step_output: torch.Tensor):
+        with self._lock:
+            if step_output.ndim == 1:
+                step_output = step_output.unsqueeze(0)
+
+            batch_size, step_size = step_output.shape
+            end_idx = self.cur_step_idx + step_size
+
+            if end_idx > self.max_new_tokens:
+                self.output_tensor[:, self.cur_step_idx :] = step_output[
+                    :, : self.max_new_tokens - self.cur_step_idx
+                ]
+                self.cur_step_idx = self.max_new_tokens
+            else:
+                self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+                self.cur_step_idx = end_idx
+
+            if self.is_finished():
+                self.mark_finished()
+
+    def get_final_output(self) -> torch.Tensor:
+        with self._lock:
+            return self.output_tensor[:, : self.cur_step_idx]
+
+    def notify_completion(self):
+        with self._lock:
+            self._condition.notify_all()
+
+    def wait_for_completion(self) -> Any:
+        with self._lock:
+            while self.state != InferenceRequestState.FINISHED:
+                self._condition.wait()
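
InferenceRequest accumulates per-step model outputs into a preallocated [batch_size, max_new_tokens] buffer and lets a caller block until the request finishes. Note that mark_finished() only flips the state; a worker still has to call notify_completion() to wake threads blocked in wait_for_completion(). A minimal usage sketch follows (not part of the commit; strategy=None is a placeholder where a concrete AbstractStrategy implementation would be passed):

import threading
import torch

# Assumes the classes added above are in scope.
req = InferenceRequest(req_id=1, inputs=torch.randn(128), strategy=None)

def worker(request):
    request.mark_running()
    while not request.is_finished():
        step = torch.randn(request.batch_size, 16)  # stand-in for one decode step of 16 points
        request.write_step_output(step)             # marks FINISHED once the buffer is full
    request.notify_completion()                     # wake threads blocked in wait_for_completion()

threading.Thread(target=worker, args=(req,)).start()
req.wait_for_completion()
print(req.get_final_output().shape)                 # torch.Size([1, 96])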
