
Commit a415538

CRZbulabula authored and JackieTien97 committed
[AINode] Revert transformer and tokenizer dependencies update (#16394)
(cherry picked from commit 4429187)
1 parent 7fc931d · commit a415538

3 files changed: +125 -41 lines changed


integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java

Lines changed: 82 additions & 32 deletions
@@ -56,26 +56,12 @@ public class AINodeInferenceSQLIT {
         "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
         "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
         "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(1,1.0,2.0,3,4)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(2,2.0,3.0,4,5)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(3,3.0,4.0,5,6)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(4,4.0,5.0,6,7)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(5,5.0,6.0,7,8)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(6,6.0,7.0,8,9)",
-        "insert into root.AI(timestamp,s0,s1,s2,s3) values(7,7.0,8.0,9,10)",
       };
 
   static String[] WRITE_SQL_IN_TABLE =
       new String[] {
         "CREATE DATABASE root",
         "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(1,1.0,2.0,3,4)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(2,2.0,3.0,4,5)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(3,3.0,4.0,5,6)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(4,4.0,5.0,6,7)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(5,5.0,6.0,7,8)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(6,6.0,7.0,8,9)",
-        "insert into root.AI(time,s0,s1,s2,s3) values(7,7.0,8.0,9,10)",
       };
 
   @BeforeClass
@@ -84,6 +70,24 @@ public static void setUp() throws Exception {
     EnvFactory.getEnv().initClusterEnvironment(1, 1);
     prepareData(WRITE_SQL_IN_TREE);
     prepareTableData(WRITE_SQL_IN_TABLE);
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      for (int i = 0; i < 2880; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+                i, (float) i, (double) i, i, i));
+      }
+    }
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      for (int i = 0; i < 2880; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+                i, (float) i, (double) i, i, i));
+      }
+    }
   }
 
   @AfterClass
@@ -109,6 +113,29 @@ public void callInferenceTestInTable() throws SQLException {
   }
 
   public void callInferenceTest(Statement statement) throws SQLException {
+    // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success
+    try (ResultSet resultSet =
+        statement.executeQuery(
+            "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) {
+      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+      checkHeader(resultSetMetaData, "Time,output0");
+      int count = 0;
+      while (resultSet.next()) {
+        count++;
+      }
+      assertEquals(720, count);
+    }
+    try (ResultSet resultSet =
+        statement.executeQuery(
+            "CALL INFERENCE(timer_xl, \"select s2 from root.AI\", generateTime=true, predict_length=256)")) {
+      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+      checkHeader(resultSetMetaData, "Time,output0");
+      int count = 0;
+      while (resultSet.next()) {
+        count++;
+      }
+      assertEquals(256, count);
+    }
     // SQL1: user-defined model inferences multi-columns with generateTime=true
     String sql1 =
         "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)";
@@ -171,15 +198,15 @@ public void callInferenceTest(Statement statement) throws SQLException {
     // assertEquals(3, count);
     // }
 
-    try (ResultSet resultSet = statement.executeQuery(sql4)) {
-      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-      checkHeader(resultSetMetaData, "Time,output0");
-      int count = 0;
-      while (resultSet.next()) {
-        count++;
-      }
-      assertEquals(6, count);
-    }
+    // try (ResultSet resultSet = statement.executeQuery(sql4)) {
+    // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+    // checkHeader(resultSetMetaData, "Time,output0");
+    // int count = 0;
+    // while (resultSet.next()) {
+    // count++;
+    // }
+    // assertEquals(6, count);
+    // }
   }
 
   @Test
@@ -219,6 +246,29 @@ public void errorCallInferenceTest(Statement statement) {
   public void selectForecastTestInTable() throws SQLException {
     try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
+      // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success
+      try (ResultSet resultSet =
+          statement.executeQuery(
+              "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) {
+        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+        checkHeader(resultSetMetaData, "time,s1");
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        assertEquals(720, count);
+      }
+      try (ResultSet resultSet =
+          statement.executeQuery(
+              "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) {
+        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+        checkHeader(resultSetMetaData, "time,s2");
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        assertEquals(256, count);
+      }
       // SQL1: user-defined model inferences multi-columns with generateTime=true
       String sql1 =
           "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM root.AI) ORDER BY time, output_length=>7)";
@@ -280,15 +330,15 @@ public void selectForecastTestInTable() throws SQLException {
       // assertEquals(3, count);
       // }
 
-      try (ResultSet resultSet = statement.executeQuery(sql4)) {
-        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-        checkHeader(resultSetMetaData, "time,s0");
-        int count = 0;
-        while (resultSet.next()) {
-          count++;
-        }
-        assertEquals(6, count);
-      }
+      // try (ResultSet resultSet = statement.executeQuery(sql4)) {
+      // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+      // checkHeader(resultSetMetaData, "time,s0");
+      // int count = 0;
+      // while (resultSet.next()) {
+      // count++;
+      // }
+      // assertEquals(6, count);
+      // }
     }
   }
 }

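To make the new test scenario concrete outside the JUnit IT harness, the following is a minimal sketch using the apache-iotdb Python client. The host, port, credentials, and the single-series insert are assumptions; it requires a running cluster with an AINode and the built-in sundial model available, and it is not part of this commit.

# Hypothetical reproduction of the new test flow: write 2880 points, then ask
# the built-in sundial model for a 720-point forecast via CALL INFERENCE.
from iotdb.Session import Session

session = Session("127.0.0.1", 6667, "root", "root")  # assumed local cluster
session.open(False)

# Mirrors the data-preparation loop added to setUp(), for one series only.
for i in range(2880):
    session.execute_non_query_statement(
        f"INSERT INTO root.AI(timestamp,s1) VALUES({i},{float(i)})"
    )

# The IT asserts that exactly predict_length rows come back (720 here).
result = session.execute_query_statement(
    'CALL INFERENCE(sundial, "select s1 from root.AI", generateTime=true, predict_length=720)'
)
count = 0
while result.has_next():
    result.next()
    count += 1
print(count)  # expected: 720

session.close()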
iotdb-core/ainode/ainode/core/manager/inference_manager.py

Lines changed: 10 additions & 5 deletions
@@ -84,7 +84,8 @@ class TimerXLStrategy(InferenceStrategy):
     def infer(self, full_data, predict_length=96, **_):
         data = full_data[1][0]
         if data.dtype.byteorder not in ("=", "|"):
-            data = data.byteswap().newbyteorder()
+            np_data = data.byteswap()
+            data = np_data.view(np_data.dtype.newbyteorder())
         seqs = torch.tensor(data).unsqueeze(0).float()
         # TODO: unify model inference input
         output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True)
@@ -96,7 +97,8 @@ class SundialStrategy(InferenceStrategy):
     def infer(self, full_data, predict_length=96, **_):
         data = full_data[1][0]
         if data.dtype.byteorder not in ("=", "|"):
-            data = data.byteswap().newbyteorder()
+            np_data = data.byteswap()
+            data = np_data.view(np_data.dtype.newbyteorder())
         seqs = torch.tensor(data).unsqueeze(0).float()
         # TODO: unify model inference input
         output = self.model.generate(
@@ -270,7 +272,7 @@ def _run(
         full_data = deserializer(raw)
         inference_attrs = extract_attrs(req)
 
-        predict_length = int(inference_attrs.get("predict_length", 96))
+        predict_length = int(inference_attrs.pop("predict_length", 96))
         if (
             predict_length
             > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
@@ -307,7 +309,8 @@ def _run(
         # TODO: TSBlock -> Tensor codes should be unified
         data = full_data[1][0]
         if data.dtype.byteorder not in ("=", "|"):
-            data = data.byteswap().newbyteorder()
+            np_data = data.byteswap()
+            data = np_data.view(np_data.dtype.newbyteorder())
         # the inputs should be on CPU before passing to the inference request
         inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
         if model_id == "sundial":
@@ -340,7 +343,9 @@ def _run(
         model = self._model_manager.load_model(model_id, inference_attrs, accel)
         # inference by strategy
         strategy = self._get_strategy(model_id, model)
-        outputs = strategy.infer(full_data, **inference_attrs)
+        outputs = strategy.infer(
+            full_data, predict_length=predict_length, **inference_attrs
+        )
 
         # construct response
         status = get_status(TSStatusCode.SUCCESS_STATUS)

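All three hunks above apply the same byte-order fix: NumPy 2.x, which the pyproject.toml change below now allows, removed ndarray.newbyteorder(), so the byte-swapped buffer is instead reinterpreted through a dtype whose byte order has been flipped. A self-contained sketch of that pattern follows; the helper name is mine and does not appear in the codebase.

import numpy as np

def to_native_byteorder(arr: np.ndarray) -> np.ndarray:
    # Same check as inference_manager.py: "=" means native order, "|" means
    # byte order is not applicable (e.g. single-byte types).
    if arr.dtype.byteorder in ("=", "|"):
        return arr
    swapped = arr.byteswap()  # swap the raw bytes in memory
    # Relabel the dtype so it describes the swapped layout; this works on both
    # NumPy 1.x and 2.x, unlike the removed ndarray.newbyteorder() method.
    return swapped.view(swapped.dtype.newbyteorder())

# Example: a big-endian float64 buffer, as a network deserializer might produce.
be = np.array([1.0, 2.0, 3.0], dtype=">f8")
native = to_native_byteorder(be)
assert np.array_equal(native, [1.0, 2.0, 3.0])

In the diff, this normalization sits immediately before the torch.tensor(...) conversion in each strategy, so the tensors are always built from native-order arrays.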
iotdb-core/ainode/pyproject.toml

Lines changed: 33 additions & 4 deletions
@@ -55,10 +55,39 @@ packages = [
 ]
 
 [tool.poetry.dependencies]
-python = ">=3.9, <3.13"
-numpy = "^1.21.4"
-pandas = ">=2.2.0"
-torch = ">=2.7.1"
+python = ">=3.11,<=3.13.5"
+
+# Core scientific stack
+numpy = [
+    { version = "^2.3.2", python = ">=3.10" },
+    { version = "^1.26.4", python = ">=3.9,<3.10" }
+]
+scipy = [
+    { version = "^1.12.0", python = ">=3.10" },
+    { version = "^1.11.4", python = ">=3.9,<3.10" }
+]
+pandas = "^2.3.2"
+scikit-learn = [
+    { version = "^1.7.1", python = ">=3.10" },
+    { version = "^1.5.2", python = ">=3.9,<3.10" }
+]
+statsmodels = "^0.14.5"
+sktime = "0.38.5"
+
+# ---- DL / HF stack ----
+torch = ">=2.7.0"
+torchmetrics = "^1.8.0"
+transformers = "==4.40.1"
+tokenizers = ">=0.19,<0.20"
+huggingface_hub = "^0.34.4"
+safetensors = "^0.6.2"
+einops = "^0.8.1"
+
+# ---- Optimizers / utils ----
+optuna = "^4.4.0"
+psutil = "^7.0.0"
+requests = "^2.32.5"
+dynaconf = "^3.2.11"
 thrift = ">=0.14.0"
 dynaconf = "^3.1.11"
 requests = "^2.31.0"

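Because the point of the revert is to keep the Hugging Face stack on the pinned versions (transformers == 4.40.1, tokenizers >= 0.19,<0.20), a deployment can fail fast when the installed packages drift from those pins. A hypothetical sanity check follows, assuming the packaging library is available; it is not part of this commit.

from importlib.metadata import version

from packaging.version import Version

def check_hf_pins() -> None:
    # Compare the installed versions against the pins restored by this revert.
    tf = Version(version("transformers"))
    tk = Version(version("tokenizers"))
    assert tf == Version("4.40.1"), f"transformers {tf} does not match the ==4.40.1 pin"
    assert Version("0.19") <= tk < Version("0.20"), f"tokenizers {tk} is outside >=0.19,<0.20"

if __name__ == "__main__":
    check_hf_pins()
    print("Hugging Face stack matches the reverted pins")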