Skip to content

Commit 52002e8

Browse files
authored
[AINode] Append model management IT (#16938)
1 parent 7722963 commit 52002e8

File tree

9 files changed

+150
-96
lines changed

9 files changed

+150
-96
lines changed

integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
6060
public static final String CONFIG_PATH = "conf";
6161
public static final String SCRIPT_PATH = "sbin";
6262
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
63-
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
63+
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models";
6464

6565
private void replaceAttribute(String[] keys, String[] values, String filePath) {
6666
Properties props = new Properties();

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,12 @@
4040

4141
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
4242
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
43-
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
4444

4545
@RunWith(IoTDBTestRunner.class)
4646
@Category({AIClusterIT.class})
4747
public class AINodeCallInferenceIT {
4848

49-
private static final String[] WRITE_SQL_IN_TREE =
50-
new String[] {
51-
"CREATE DATABASE root.AI",
52-
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
53-
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
54-
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
55-
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
56-
};
57-
5849
private static final String CALL_INFERENCE_SQL_TEMPLATE =
5950
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
6051
private static final int DEFAULT_INPUT_LENGTH = 256;
@@ -64,16 +55,7 @@ public class AINodeCallInferenceIT {
6455
public static void setUp() throws Exception {
6556
// Init 1C1D1A cluster environment
6657
EnvFactory.getEnv().initClusterEnvironment(1, 1);
67-
prepareData(WRITE_SQL_IN_TREE);
68-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
69-
Statement statement = connection.createStatement()) {
70-
for (int i = 0; i < 2880; i++) {
71-
statement.execute(
72-
String.format(
73-
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
74-
i, (float) i, (double) i, i, i));
75-
}
76-
}
58+
prepareDataInTree();
7759
}
7860

7961
@AfterClass
@@ -91,7 +73,7 @@ public void callInferenceTest() throws SQLException {
9173
}
9274
}
9375

94-
public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
76+
public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
9577
throws SQLException {
9678
// Invoke call inference for specified models, there should exist result.
9779
for (int i = 0; i < 4; i++) {

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
4141
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
4243

4344
@RunWith(IoTDBTestRunner.class)
4445
@Category({AIClusterIT.class})
@@ -58,18 +59,7 @@ public class AINodeForecastIT {
5859
public static void setUp() throws Exception {
5960
// Init 1C1D1A cluster environment
6061
EnvFactory.getEnv().initClusterEnvironment(1, 1);
61-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
62-
Statement statement = connection.createStatement()) {
63-
statement.execute("CREATE DATABASE db");
64-
statement.execute(
65-
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
66-
for (int i = 0; i < 5760; i++) {
67-
statement.execute(
68-
String.format(
69-
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
70-
i, (float) i, (double) i, i, i));
71-
}
72-
}
62+
prepareDataInTable();
7363
}
7464

7565
@AfterClass
@@ -87,7 +77,7 @@ public void forecastTableFunctionTest() throws SQLException {
8777
}
8878
}
8979

90-
public void forecastTableFunctionTest(
80+
public static void forecastTableFunctionTest(
9181
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
9282
// Invoke forecast table function for specified models, there should exist result.
9383
for (int i = 0; i < 4; i++) {

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@
3939
import java.sql.Statement;
4040
import java.util.concurrent.TimeUnit;
4141

42+
import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest;
43+
import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest;
4244
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
4345
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
46+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
47+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
4448
import static org.junit.Assert.assertEquals;
4549
import static org.junit.Assert.assertFalse;
4650
import static org.junit.Assert.assertTrue;
@@ -54,54 +58,70 @@ public class AINodeModelManageIT {
5458
public static void setUp() throws Exception {
5559
// Init 1C1D1A cluster environment
5660
EnvFactory.getEnv().initClusterEnvironment(1, 1);
61+
prepareDataInTree();
62+
prepareDataInTable();
5763
}
5864

5965
@AfterClass
6066
public static void tearDown() throws Exception {
6167
EnvFactory.getEnv().cleanClusterEnvironment();
6268
}
6369

64-
// @Test
70+
@Test
6571
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
6672
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
6773
Statement statement = connection.createStatement()) {
68-
userDefinedModelManagementTest(statement);
74+
registerUserDefinedModel(statement);
75+
callInferenceTest(
76+
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
77+
dropUserDefinedModel(statement);
78+
errorTest(
79+
statement,
80+
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
81+
"1505: 't5' is already used by a Transformers config, pick another name.");
82+
statement.execute("drop model origin_chronos");
6983
}
7084
}
7185

72-
// @Test
86+
@Test
7387
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
7488
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
7589
Statement statement = connection.createStatement()) {
76-
userDefinedModelManagementTest(statement);
90+
registerUserDefinedModel(statement);
91+
forecastTableFunctionTest(
92+
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
93+
dropUserDefinedModel(statement);
94+
errorTest(
95+
statement,
96+
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
97+
"1505: 't5' is already used by a Transformers config, pick another name.");
98+
statement.execute("drop model origin_chronos");
7799
}
78100
}
79101

80-
private void userDefinedModelManagementTest(Statement statement)
102+
private void registerUserDefinedModel(Statement statement)
81103
throws SQLException, InterruptedException {
82104
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
83-
final String registerSql = "create model operationTest using uri \"" + "\"";
84-
final String showSql = "SHOW MODELS operationTest";
85-
final String dropSql = "DROP MODEL operationTest";
86-
105+
final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\"";
106+
final String showSql = "SHOW MODELS user_chronos";
87107
statement.execute(alterConfigSQL);
88108
statement.execute(registerSql);
89109
boolean loading = true;
90-
int count = 0;
91110
for (int retryCnt = 0; retryCnt < 100; retryCnt++) {
92111
try (ResultSet resultSet = statement.executeQuery(showSql)) {
93112
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
94113
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
95114
while (resultSet.next()) {
96115
String modelId = resultSet.getString(1);
116+
String modelType = resultSet.getString(2);
97117
String category = resultSet.getString(3);
98118
String state = resultSet.getString(4);
99-
assertEquals("operationTest", modelId);
100-
assertEquals("USER-DEFINED", category);
101-
if (state.equals("ACTIVE")) {
119+
assertEquals("user_chronos", modelId);
120+
assertEquals("custom_t5", modelType);
121+
assertEquals("user_defined", category);
122+
if (state.equals("active")) {
102123
loading = false;
103-
count++;
104-
} else if (state.equals("LOADING")) {
124+
} else if (state.equals("loading")) {
105125
break;
106126
} else {
107127
fail("Unexpected status of model: " + state);
@@ -114,12 +134,16 @@ private void userDefinedModelManagementTest(Statement statement)
114134
TimeUnit.SECONDS.sleep(1);
115135
}
116136
assertFalse(loading);
117-
assertEquals(1, count);
137+
}
138+
139+
private void dropUserDefinedModel(Statement statement) throws SQLException {
140+
final String showSql = "SHOW MODELS user_chronos";
141+
final String dropSql = "DROP MODEL user_chronos";
118142
statement.execute(dropSql);
119143
try (ResultSet resultSet = statement.executeQuery(showSql)) {
120144
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
121145
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
122-
count = 0;
146+
int count = 0;
123147
while (resultSet.next()) {
124148
count++;
125149
}

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@
1919

2020
package org.apache.iotdb.ainode.utils;
2121

22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.itbase.env.BaseEnv;
24+
2225
import com.google.common.collect.ImmutableSet;
2326
import org.junit.Assert;
2427
import org.slf4j.Logger;
2528
import org.slf4j.LoggerFactory;
2629

30+
import java.sql.Connection;
2731
import java.sql.ResultSet;
2832
import java.sql.ResultSetMetaData;
2933
import java.sql.SQLException;
@@ -39,6 +43,7 @@
3943
import java.util.stream.Collectors;
4044
import java.util.stream.Stream;
4145

46+
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
4247
import static org.junit.Assert.assertEquals;
4348
import static org.junit.Assert.fail;
4449

@@ -206,6 +211,45 @@ public static void checkModelNotOnSpecifiedDevice(
206211
fail("Model " + modelId + " is still loaded on device " + device);
207212
}
208213

214+
private static final String[] WRITE_SQL_IN_TREE =
215+
new String[] {
216+
"CREATE DATABASE root.AI",
217+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
218+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
219+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
220+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
221+
};
222+
223+
/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
224+
public static void prepareDataInTree() throws SQLException {
225+
prepareData(WRITE_SQL_IN_TREE);
226+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
227+
Statement statement = connection.createStatement()) {
228+
for (int i = 0; i < 5760; i++) {
229+
statement.execute(
230+
String.format(
231+
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
232+
i, (float) i, (double) i, i, i));
233+
}
234+
}
235+
}
236+
237+
/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
238+
public static void prepareDataInTable() throws SQLException {
239+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
240+
Statement statement = connection.createStatement()) {
241+
statement.execute("CREATE DATABASE db");
242+
statement.execute(
243+
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
244+
for (int i = 0; i < 5760; i++) {
245+
statement.execute(
246+
String.format(
247+
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
248+
i, (float) i, (double) i, i, i));
249+
}
250+
}
251+
}
252+
209253
public static class FakeModelInfo {
210254

211255
private final String modelId;

iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ def register_model(
6161
return TRegisterModelResp(
6262
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
6363
)
64+
except Exception as e:
65+
# Catch-all for other exceptions (mainly from transformers implementation)
66+
return TRegisterModelResp(
67+
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
68+
)
6469

6570
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
66-
self._refresh()
6771
return self._model_storage.show_models(req)
6872

6973
def delete_model(self, req: TDeleteModelReq) -> TSStatus:

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
pipeline_cls: str = "",
3232
repo_id: str = "",
3333
auto_map: Optional[Dict] = None,
34-
_transformers_registered: bool = False,
34+
transformers_registered: bool = False,
3535
):
3636
self.model_id = model_id
3737
self.model_type = model_type
@@ -40,7 +40,9 @@ def __init__(
4040
self.pipeline_cls = pipeline_cls
4141
self.repo_id = repo_id
4242
self.auto_map = auto_map # If exists, indicates it's a Transformers model
43-
self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers
43+
self.transformers_registered = (
44+
transformers_registered # Internal flag: whether registered to Transformers
45+
)
4446

4547
def __repr__(self):
4648
return (
@@ -116,7 +118,7 @@ def __repr__(self):
116118
"AutoConfig": "configuration_timer.TimerConfig",
117119
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
118120
},
119-
_transformers_registered=True,
121+
transformers_registered=True,
120122
),
121123
"sundial": ModelInfo(
122124
model_id="sundial",
@@ -129,7 +131,7 @@ def __repr__(self):
129131
"AutoConfig": "configuration_sundial.SundialConfig",
130132
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
131133
},
132-
_transformers_registered=True,
134+
transformers_registered=True,
133135
),
134136
"chronos2": ModelInfo(
135137
model_id="chronos2",
@@ -139,7 +141,7 @@ def __repr__(self):
139141
pipeline_cls="pipeline_chronos2.Chronos2Pipeline",
140142
repo_id="amazon/chronos-2",
141143
auto_map={
142-
"AutoConfig": "config.Chronos2ForecastingConfig",
144+
"AutoConfig": "config.Chronos2CoreConfig",
143145
"AutoModelForCausalLM": "model.Chronos2Model",
144146
},
145147
),

0 commit comments

Comments
 (0)