Skip to content

Commit 9868b41

Browse files
committed
finish
1 parent 340a860 commit 9868b41

File tree

8 files changed

+99
-68
lines changed

8 files changed

+99
-68
lines changed

integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
6060
public static final String CONFIG_PATH = "conf";
6161
public static final String SCRIPT_PATH = "sbin";
6262
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
63-
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
63+
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models";
6464

6565
private void replaceAttribute(String[] keys, String[] values, String filePath) {
6666
Properties props = new Properties();

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,12 @@
4040

4141
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
4242
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
43-
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
4444

4545
@RunWith(IoTDBTestRunner.class)
4646
@Category({AIClusterIT.class})
4747
public class AINodeCallInferenceIT {
4848

49-
private static final String[] WRITE_SQL_IN_TREE =
50-
new String[] {
51-
"CREATE DATABASE root.AI",
52-
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
53-
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
54-
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
55-
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
56-
};
57-
5849
private static final String CALL_INFERENCE_SQL_TEMPLATE =
5950
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
6051
private static final int DEFAULT_INPUT_LENGTH = 256;
@@ -64,16 +55,7 @@ public class AINodeCallInferenceIT {
6455
public static void setUp() throws Exception {
6556
// Init 1C1D1A cluster environment
6657
EnvFactory.getEnv().initClusterEnvironment(1, 1);
67-
prepareData(WRITE_SQL_IN_TREE);
68-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
69-
Statement statement = connection.createStatement()) {
70-
for (int i = 0; i < 2880; i++) {
71-
statement.execute(
72-
String.format(
73-
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
74-
i, (float) i, (double) i, i, i));
75-
}
76-
}
58+
prepareDataInTree();
7759
}
7860

7961
@AfterClass
@@ -91,7 +73,7 @@ public void callInferenceTest() throws SQLException {
9173
}
9274
}
9375

94-
public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
76+
public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
9577
throws SQLException {
9678
// Invoke call inference for specified models, there should exist result.
9779
for (int i = 0; i < 4; i++) {

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
4141
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
4243

4344
@RunWith(IoTDBTestRunner.class)
4445
@Category({AIClusterIT.class})
@@ -58,18 +59,7 @@ public class AINodeForecastIT {
5859
public static void setUp() throws Exception {
5960
// Init 1C1D1A cluster environment
6061
EnvFactory.getEnv().initClusterEnvironment(1, 1);
61-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
62-
Statement statement = connection.createStatement()) {
63-
statement.execute("CREATE DATABASE db");
64-
statement.execute(
65-
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
66-
for (int i = 0; i < 5760; i++) {
67-
statement.execute(
68-
String.format(
69-
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
70-
i, (float) i, (double) i, i, i));
71-
}
72-
}
62+
prepareDataInTable();
7363
}
7464

7565
@AfterClass
@@ -87,7 +77,7 @@ public void forecastTableFunctionTest() throws SQLException {
8777
}
8878
}
8979

90-
public void forecastTableFunctionTest(
80+
public static void forecastTableFunctionTest(
9181
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
9282
// Invoke forecast table function for specified models, there should exist result.
9383
for (int i = 0; i < 4; i++) {

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,12 @@
3939
import java.sql.Statement;
4040
import java.util.concurrent.TimeUnit;
4141

42+
import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest;
43+
import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest;
4244
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
4345
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
46+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
47+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
4448
import static org.junit.Assert.assertEquals;
4549
import static org.junit.Assert.assertFalse;
4650
import static org.junit.Assert.assertTrue;
@@ -54,54 +58,60 @@ public class AINodeModelManageIT {
5458
public static void setUp() throws Exception {
5559
// Init 1C1D1A cluster environment
5660
EnvFactory.getEnv().initClusterEnvironment(1, 1);
61+
prepareDataInTree();
62+
prepareDataInTable();
5763
}
5864

5965
@AfterClass
6066
public static void tearDown() throws Exception {
6167
EnvFactory.getEnv().cleanClusterEnvironment();
6268
}
6369

64-
// @Test
70+
@Test
6571
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
6672
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
6773
Statement statement = connection.createStatement()) {
68-
userDefinedModelManagementTest(statement);
74+
registerUserDefinedModel(statement);
75+
callInferenceTest(
76+
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
77+
dropUserDefinedModel(statement);
6978
}
7079
}
7180

72-
// @Test
81+
@Test
7382
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
7483
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
7584
Statement statement = connection.createStatement()) {
76-
userDefinedModelManagementTest(statement);
85+
registerUserDefinedModel(statement);
86+
forecastTableFunctionTest(
87+
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
88+
dropUserDefinedModel(statement);
7789
}
7890
}
7991

80-
private void userDefinedModelManagementTest(Statement statement)
92+
private void registerUserDefinedModel(Statement statement)
8193
throws SQLException, InterruptedException {
8294
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
83-
final String registerSql = "create model operationTest using uri \"" + "\"";
84-
final String showSql = "SHOW MODELS operationTest";
85-
final String dropSql = "DROP MODEL operationTest";
86-
95+
final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\"";
96+
final String showSql = "SHOW MODELS user_chronos";
8797
statement.execute(alterConfigSQL);
8898
statement.execute(registerSql);
8999
boolean loading = true;
90-
int count = 0;
91100
for (int retryCnt = 0; retryCnt < 100; retryCnt++) {
92101
try (ResultSet resultSet = statement.executeQuery(showSql)) {
93102
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
94103
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
95104
while (resultSet.next()) {
96105
String modelId = resultSet.getString(1);
106+
String modelType = resultSet.getString(2);
97107
String category = resultSet.getString(3);
98108
String state = resultSet.getString(4);
99-
assertEquals("operationTest", modelId);
100-
assertEquals("USER-DEFINED", category);
101-
if (state.equals("ACTIVE")) {
109+
assertEquals("user_chronos", modelId);
110+
assertEquals("user_defined", category);
111+
assertEquals("custom_t5", modelType);
112+
if (state.equals("active")) {
102113
loading = false;
103-
count++;
104-
} else if (state.equals("LOADING")) {
114+
} else if (state.equals("loading")) {
105115
break;
106116
} else {
107117
fail("Unexpected status of model: " + state);
@@ -114,12 +124,16 @@ private void userDefinedModelManagementTest(Statement statement)
114124
TimeUnit.SECONDS.sleep(1);
115125
}
116126
assertFalse(loading);
117-
assertEquals(1, count);
127+
}
128+
129+
private void dropUserDefinedModel(Statement statement) throws SQLException {
130+
final String showSql = "SHOW MODELS user_chronos";
131+
final String dropSql = "DROP MODEL user_chronos";
118132
statement.execute(dropSql);
119133
try (ResultSet resultSet = statement.executeQuery(showSql)) {
120134
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
121135
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
122-
count = 0;
136+
int count = 0;
123137
while (resultSet.next()) {
124138
count++;
125139
}

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,15 @@
1919

2020
package org.apache.iotdb.ainode.utils;
2121

22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.itbase.env.BaseEnv;
24+
2225
import com.google.common.collect.ImmutableSet;
2326
import org.junit.Assert;
2427
import org.slf4j.Logger;
2528
import org.slf4j.LoggerFactory;
2629

30+
import java.sql.Connection;
2731
import java.sql.ResultSet;
2832
import java.sql.ResultSetMetaData;
2933
import java.sql.SQLException;
@@ -39,6 +43,7 @@
3943
import java.util.stream.Collectors;
4044
import java.util.stream.Stream;
4145

46+
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
4247
import static org.junit.Assert.assertEquals;
4348
import static org.junit.Assert.fail;
4449

@@ -206,6 +211,45 @@ public static void checkModelNotOnSpecifiedDevice(
206211
fail("Model " + modelId + " is still loaded on device " + device);
207212
}
208213

214+
private static final String[] WRITE_SQL_IN_TREE =
215+
new String[] {
216+
"CREATE DATABASE root.AI",
217+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
218+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
219+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
220+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
221+
};
222+
223+
/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
224+
public static void prepareDataInTree() throws SQLException {
225+
prepareData(WRITE_SQL_IN_TREE);
226+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
227+
Statement statement = connection.createStatement()) {
228+
for (int i = 0; i < 5760; i++) {
229+
statement.execute(
230+
String.format(
231+
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
232+
i, (float) i, (double) i, i, i));
233+
}
234+
}
235+
}
236+
237+
/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
238+
public static void prepareDataInTable() throws SQLException {
239+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
240+
Statement statement = connection.createStatement()) {
241+
statement.execute("CREATE DATABASE db");
242+
statement.execute(
243+
"CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
244+
for (int i = 0; i < 5760; i++) {
245+
statement.execute(
246+
String.format(
247+
"INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
248+
i, (float) i, (double) i, i, i));
249+
}
250+
}
251+
}
252+
209253
public static class FakeModelInfo {
210254

211255
private final String modelId;

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
pipeline_cls: str = "",
3232
repo_id: str = "",
3333
auto_map: Optional[Dict] = None,
34-
_transformers_registered: bool = False,
34+
transformers_registered: bool = False,
3535
):
3636
self.model_id = model_id
3737
self.model_type = model_type
@@ -40,7 +40,9 @@ def __init__(
4040
self.pipeline_cls = pipeline_cls
4141
self.repo_id = repo_id
4242
self.auto_map = auto_map # If exists, indicates it's a Transformers model
43-
self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers
43+
self.transformers_registered = (
44+
transformers_registered # Internal flag: whether registered to Transformers
45+
)
4446

4547
def __repr__(self):
4648
return (
@@ -116,7 +118,7 @@ def __repr__(self):
116118
"AutoConfig": "configuration_timer.TimerConfig",
117119
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
118120
},
119-
_transformers_registered=True,
121+
transformers_registered=True,
120122
),
121123
"sundial": ModelInfo(
122124
model_id="sundial",
@@ -129,7 +131,7 @@ def __repr__(self):
129131
"AutoConfig": "configuration_sundial.SundialConfig",
130132
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
131133
},
132-
_transformers_registered=True,
134+
transformers_registered=True,
133135
),
134136
"chronos2": ModelInfo(
135137
model_id="chronos2",
@@ -139,7 +141,7 @@ def __repr__(self):
139141
pipeline_cls="pipeline_chronos2.Chronos2Pipeline",
140142
repo_id="amazon/chronos-2",
141143
auto_map={
142-
"AutoConfig": "config.Chronos2ForecastingConfig",
144+
"AutoConfig": "config.Chronos2CoreConfig",
143145
"AutoModelForCausalLM": "model.Chronos2Model",
144146
},
145147
),

iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str):
236236
state=ModelStates.ACTIVE,
237237
pipeline_cls=pipeline_cls,
238238
auto_map=auto_map,
239-
_transformers_registered=False, # Lazy registration
239+
transformers_registered=False, # Lazy registration
240240
)
241241
self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
242242

@@ -287,7 +287,7 @@ def register_model(self, model_id: str, uri: str):
287287
state=ModelStates.ACTIVE,
288288
pipeline_cls=pipeline_cls,
289289
auto_map=auto_map,
290-
_transformers_registered=False, # Register later
290+
transformers_registered=False, # Register later
291291
)
292292
self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
293293

@@ -296,7 +296,7 @@ def register_model(self, model_id: str, uri: str):
296296
success = self._register_transformers_model(model_info)
297297
if success:
298298
with self._lock_pool.get_lock(model_id).write_lock():
299-
model_info._transformers_registered = True
299+
model_info.transformers_registered = True
300300
else:
301301
with self._lock_pool.get_lock(model_id).write_lock():
302302
model_info.state = ModelStates.INACTIVE
@@ -352,7 +352,7 @@ def _register_other_model(self, model_info: ModelInfo):
352352
f"Registered other type model: {model_info.model_id} ({model_info.model_type})"
353353
)
354354

355-
def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
355+
def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None:
356356
"""
357357
Ensure Transformers model is registered (called for lazy registration)
358358
This method uses locks to ensure thread safety. All check logic is within lock protection.
@@ -369,26 +369,25 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
369369
break
370370

371371
if not model_info:
372-
logger.warning(f"Model {model_id} does not exist, cannot register")
373372
return None
374373

375374
# If already registered, return directly
376-
if model_info._transformers_registered:
375+
if model_info.transformers_registered:
377376
return model_info
378377

379378
# If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks)
380379
if (
381380
not model_info.auto_map
382381
or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys()
383382
):
384-
model_info._transformers_registered = True
383+
model_info.transformers_registered = True
385384
return model_info
386385

387386
# Execute registration (under lock protection)
388387
try:
389388
success = self._register_transformers_model(model_info)
390389
if success:
391-
model_info._transformers_registered = True
390+
model_info.transformers_registered = True
392391
logger.info(
393392
f"Model {model_id} successfully registered to Transformers"
394393
)
@@ -401,7 +400,7 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
401400
except Exception as e:
402401
# Ensure state consistency in exception cases
403402
model_info.state = ModelStates.INACTIVE
404-
model_info._transformers_registered = False
403+
model_info.transformers_registered = False
405404
logger.error(
406405
f"Exception occurred while registering model {model_id} to Transformers: {e}"
407406
)

0 commit comments

Comments
 (0)