
Commit 3471ea7

Merge branch 'master' of https://github.com/apache/iotdb into fix-audit-logger

2 parents: 2be31d2 + b6c13d7

21 files changed: +890 −332 lines

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
   private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);

   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
-      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
+      "SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";

   @BeforeClass
   public static void setUp() throws Exception {
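
The only change here is the rename of the input parameter to targets. As a minimal sketch of the resolved SQL, written in Python since the Java format specifiers happen to be %-compatible; 'example_model' and 96 are illustrative placeholders, not values from this test:

# Hypothetical fill of the renamed template; model id and output_length are placeholders.
template = (
    "SELECT * FROM FORECAST(model_id=>'%s', "
    "targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)"
)
print(template % ("example_model", 96))
# -> SELECT * FROM FORECAST(model_id=>'example_model', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>96)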

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java

Lines changed: 95 additions & 5 deletions
@@ -38,13 +38,21 @@
 import java.sql.Statement;

 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;

 @RunWith(IoTDBTestRunner.class)
 @Category({AIClusterIT.class})
 public class AINodeForecastIT {

   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
-      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)";
+      "SELECT * FROM FORECAST("
+          + "model_id=>'%s', "
+          + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, "
+          + "output_start_time=>%d, "
+          + "output_length=>%d, "
+          + "output_interval=>%d, "
+          + "timecol=>'%s'"
+          + ")";

   @BeforeClass
   public static void setUp() throws Exception {
@@ -55,7 +63,7 @@ public static void setUp() throws Exception {
       statement.execute("CREATE DATABASE db");
       statement.execute(
           "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
-      for (int i = 0; i < 2880; i++) {
+      for (int i = 0; i < 5760; i++) {
         statement.execute(
             String.format(
                 "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
@@ -81,18 +89,100 @@ public void forecastTableFunctionTest() throws SQLException {

   public void forecastTableFunctionTest(
       Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
-    // Invoke call inference for specified models, there should exist result.
+    // Invoke the forecast table function for the specified models; results should exist.
     for (int i = 0; i < 4; i++) {
       String forecastTableFunctionSQL =
-          String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
+          String.format(
+              FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+              modelInfo.getModelId(),
+              i,
+              5760,
+              2880,
+              5760,
+              96,
+              1,
+              "time");
       try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
         int count = 0;
         while (resultSet.next()) {
           count++;
         }
-        // Ensure the call inference return results
+        // Ensure the forecast statement returns results
         Assert.assertTrue(count > 0);
       }
     }
   }
+
+  @Test
+  public void forecastTableFunctionErrorTest() throws SQLException {
+    for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
+      try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+          Statement statement = connection.createStatement()) {
+        forecastTableFunctionErrorTest(statement, modelInfo);
+      }
+    }
+  }
+
+  public void forecastTableFunctionErrorTest(
+      Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
+    // OUTPUT_START_TIME error
+    String invalidOutputStartTimeSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5759,
+            96,
+            1,
+            "time");
+    errorTest(
+        statement,
+        invalidOutputStartTimeSQL,
+        "701: The OUTPUT_START_TIME should be greater than the maximum timestamp of target time series. Expected greater than [5759] but found [5759].");
+
+    // OUTPUT_LENGTH error
+    String invalidOutputLengthSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            0,
+            1,
+            "time");
+    errorTest(statement, invalidOutputLengthSQL, "701: OUTPUT_LENGTH should be greater than 0");
+
+    // OUTPUT_INTERVAL error
+    String invalidOutputIntervalSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            96,
+            -1,
+            "time");
+    errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL should be greater than 0");
+
+    // TIMECOL error
+    String invalidTimecolSQL2 =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelInfo.getModelId(),
+            0,
+            5760,
+            2880,
+            5760,
+            96,
+            1,
+            "s0");
+    errorTest(
+        statement, invalidTimecolSQL2, "701: The type of the column [s0] is not as expected.");
+  }
 }
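
The template now takes eight positional arguments, in order: model id, target-series suffix, the upper time bound of the input window, the input row limit, output_start_time, output_length, output_interval, and the time column name. A minimal Python sketch of the SQL the happy-path test builds ('example_model' is a stand-in for modelInfo.getModelId(), which this diff does not show; the Java %s/%d specifiers are compatible with Python's % operator):

# Mirrors the String.format call in forecastTableFunctionTest.
template = (
    "SELECT * FROM FORECAST("
    "model_id=>'%s', "
    "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, "
    "output_start_time=>%d, "
    "output_length=>%d, "
    "output_interval=>%d, "
    "timecol=>'%s')"
)
# Series s0, input rows below t=5760 limited to 2880, then 96 forecast steps
# at interval 1 starting exactly at t=5760 (one past the last input timestamp).
print(template % ("example_model", 0, 5760, 2880, 5760, 96, 1, "time"))

Each error case perturbs exactly one of these arguments (output_start_time=>5759, output_length=>0, output_interval=>-1, timecol=>'s0') to pin down the corresponding 701 validation message.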

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py

Lines changed: 10 additions & 9 deletions
@@ -42,7 +42,7 @@ def __init__(
         output_length: int = 96,
         **infer_kwargs,
     ):
-        if inputs.ndim == 1:
+        while inputs.ndim < 3:
             inputs = inputs.unsqueeze(0)

         self.req_id = req_id
@@ -54,15 +54,16 @@ def __init__(
         )

         self.batch_size = inputs.size(0)
+        self.variable_size = inputs.size(1)
         self.state = InferenceRequestState.WAITING
         self.cur_step_idx = 0  # Current write position in the output step index
         self.assigned_pool_id = -1  # The pool handling this request
         self.assigned_device_id = -1  # The device handling this request

         # Preallocate output buffer [batch_size, max_new_tokens]
         self.output_tensor = torch.zeros(
-            self.batch_size, output_length, device="cpu"
-        )  # shape: [self.batch_size, max_new_steps]
+            self.batch_size, self.variable_size, output_length, device="cpu"
+        )  # shape: [batch_size, target_count, predict_length]

     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
@@ -77,26 +78,26 @@ def is_finished(self) -> bool:
         )

     def write_step_output(self, step_output: torch.Tensor):
-        if step_output.ndim == 1:
+        while step_output.ndim < 3:
             step_output = step_output.unsqueeze(0)

-        batch_size, step_size = step_output.shape
+        batch_size, variable_size, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size

         if end_idx > self.output_length:
-            self.output_tensor[:, self.cur_step_idx :] = step_output[
-                :, : self.output_length - self.cur_step_idx
+            self.output_tensor[:, :, self.cur_step_idx :] = step_output[
+                :, :, : self.output_length - self.cur_step_idx
             ]
             self.cur_step_idx = self.output_length
         else:
-            self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
+            self.output_tensor[:, :, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx

         if self.is_finished():
             self.mark_finished()

     def get_final_output(self) -> torch.Tensor:
-        return self.output_tensor[:, : self.cur_step_idx]
+        return self.output_tensor[:, :, : self.cur_step_idx]


 class InferenceRequestProxy:
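
To see the buffer change in isolation, here is a standalone sketch of the chunked-write logic above (not the actual InferenceRequest class; the batch, target, and chunk sizes are invented for illustration):

import torch

# Fill a [batch_size, target_count, output_length] buffer from per-step
# chunks of shape [batch, targets, step], truncating the final chunk.
batch_size, variable_size, output_length = 2, 3, 96
output_tensor = torch.zeros(batch_size, variable_size, output_length)
cur_step_idx = 0

def write_step_output(step_output: torch.Tensor):
    global cur_step_idx
    while step_output.ndim < 3:  # left-pad dims: [L] -> [1, 1, L], [T, L] -> [1, T, L]
        step_output = step_output.unsqueeze(0)
    end_idx = cur_step_idx + step_output.size(-1)
    if end_idx > output_length:  # final chunk overshoots: keep only the prefix
        output_tensor[:, :, cur_step_idx:] = step_output[:, :, : output_length - cur_step_idx]
        cur_step_idx = output_length
    else:
        output_tensor[:, :, cur_step_idx:end_idx] = step_output
        cur_step_idx = end_idx

for _ in range(3):  # three 40-step chunks: 96 values kept, 24 discarded
    write_step_output(torch.randn(batch_size, variable_size, 40))
assert cur_step_idx == output_length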

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 4 additions & 0 deletions
@@ -123,6 +123,7 @@ def _step(self):
             batch_inputs = self._batcher.batch_request(requests).to(
                 "cpu"
             )  # The input data should first load to CPU in current version
+            batch_inputs = self._inference_pipeline.preprocess(batch_inputs)
             if isinstance(self._inference_pipeline, ForecastPipeline):
                 batch_output = self._inference_pipeline.forecast(
                     batch_inputs,
@@ -140,7 +141,10 @@ def _step(self):
                     # more infer kwargs can be added here
                 )
             else:
+                batch_output = None
                 self._logger.error("[Inference] Unsupported pipeline type.")
+            batch_output = self._inference_pipeline.postprocess(batch_output)
+
             offset = 0
             for request in requests:
                 request.output_tensor = request.output_tensor.to(self.device)
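
Simplified, the reworked _step now brackets the pipeline call with the new public preprocess/postprocess hooks. A sketch of the control flow (batching, device placement, and the forecast kwargs are omitted; the import path follows this commit's module layout):

import logging

from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline

logger = logging.getLogger(__name__)

def step(pipeline, batch_inputs):
    # Validate/transform inputs before dispatching to the concrete pipeline.
    batch_inputs = pipeline.preprocess(batch_inputs)
    if isinstance(pipeline, ForecastPipeline):
        batch_output = pipeline.forecast(batch_inputs)
    else:
        # Initializing batch_output keeps the postprocess call below from
        # raising an UnboundLocalError when the pipeline type is unsupported.
        batch_output = None
        logger.error("[Inference] Unsupported pipeline type.")
    # Validate the output shape before results are scattered back to requests.
    return pipeline.postprocess(batch_output)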

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py

Lines changed: 30 additions & 13 deletions
@@ -20,6 +20,7 @@

 import torch

+from iotdb.ainode.core.exception import InferenceModelInternalException
 from iotdb.ainode.core.model.model_loader import load_model


@@ -29,59 +30,75 @@ def __init__(self, model_info, **model_kwargs):
         self.device = model_kwargs.get("device", "cpu")
         self.model = load_model(model_info, device_map=self.device, **model_kwargs)

-    def _preprocess(self, inputs):
+    @abstractmethod
+    def preprocess(self, inputs):
         """
         Preprocess the input before inference, including shape validation and value transformation.
         """
-        return inputs
+        raise NotImplementedError("preprocess not implemented")

-    def _postprocess(self, output: torch.Tensor):
+    @abstractmethod
+    def postprocess(self, outputs: torch.Tensor):
         """
         Post-process the outputs after the entire inference task.
         """
-        return output
+        raise NotImplementedError("postprocess not implemented")


 class ForecastPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
+        """
+        The inputs should be a 3D tensor: [batch_size, target_count, sequence_length].
+        """
+        if len(inputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Input must be: [batch_size, target_count, sequence_length], but receives {inputs.shape}"
+            )
         return inputs

     @abstractmethod
     def forecast(self, inputs, **infer_kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        """
+        The outputs should be a 3D tensor: [batch_size, target_count, predict_length].
+        """
+        if len(outputs.shape) != 3:
+            raise InferenceModelInternalException(
+                f"[Inference] Output must be: [batch_size, target_count, predict_length], but receives {outputs.shape}"
+            )
+        return outputs


 class ClassificationPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     @abstractmethod
     def classify(self, inputs, **kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs


 class ChatPipeline(BasicPipeline):
     def __init__(self, model_info, **model_kwargs):
         super().__init__(model_info, model_kwargs=model_kwargs)

-    def _preprocess(self, inputs):
+    def preprocess(self, inputs):
         return inputs

     @abstractmethod
     def chat(self, inputs, **kwargs):
         pass

-    def _postprocess(self, output: torch.Tensor):
-        return output
+    def postprocess(self, outputs: torch.Tensor):
+        return outputs
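
A toy subclass illustrates the contract this refactor establishes: preprocess rejects anything that is not [batch_size, target_count, sequence_length], and postprocess re-validates the model output. RepeatLastForecastPipeline is invented for this sketch and deliberately skips load_model; only the module path comes from this commit:

import torch

from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline

class RepeatLastForecastPipeline(ForecastPipeline):
    """Toy pipeline: 'forecast' by repeating each series' last observed value."""

    def __init__(self):
        self.device = "cpu"  # skip BasicPipeline.__init__ to avoid load_model here

    def forecast(self, inputs, output_length: int = 96, **infer_kwargs):
        inputs = self.preprocess(inputs)  # enforces [batch, target, seq]
        last = inputs[:, :, -1:]  # [batch, target, 1]
        outputs = last.repeat(1, 1, output_length)  # [batch, target, predict_length]
        return self.postprocess(outputs)  # enforces the 3D output contract

pipe = RepeatLastForecastPipeline()
print(pipe.forecast(torch.randn(2, 4, 128)).shape)  # torch.Size([2, 4, 96])
# A 2D input now fails fast in preprocess with InferenceModelInternalException:
# pipe.forecast(torch.randn(2, 128))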
