Skip to content

Commit d7898c4

Browse files
CRZbulabula and RkGrit authored
[AINode] Refactoring of Model Storage, Loading, and Inference Pipeline (#16819)
Co-authored-by: RkGrit <[email protected]> Co-authored-by: Gewu <[email protected]>
1 parent a899c48 commit d7898c4

File tree

115 files changed

+3205
-6963
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

115 files changed

+3205
-6963
lines changed

.github/workflows/cluster-it-1c1d1a.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ jobs:
4141

4242
steps:
4343
- uses: actions/checkout@v4
44-
- name: Build AINode
45-
shell: bash
46-
run: mvn clean package -DskipTests -P with-ainode
4744
- name: IT Test
4845
shell: bash
4946
run: |

integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
5959
private static final String PROPERTIES_FILE = "iotdb-ainode.properties";
6060
public static final String CONFIG_PATH = "conf";
6161
public static final String SCRIPT_PATH = "sbin";
62-
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights";
62+
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
6363
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
6464

6565
private void replaceAttribute(String[] keys, String[] values, String filePath) {
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.ainode.utils.AINodeTestUtils;
23+
import org.apache.iotdb.it.env.EnvFactory;
24+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
25+
import org.apache.iotdb.itbase.category.AIClusterIT;
26+
import org.apache.iotdb.itbase.env.BaseEnv;
27+
28+
import org.junit.AfterClass;
29+
import org.junit.Assert;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.experimental.categories.Category;
33+
import org.junit.runner.RunWith;
34+
35+
import java.sql.Connection;
36+
import java.sql.ResultSet;
37+
import java.sql.ResultSetMetaData;
38+
import java.sql.SQLException;
39+
import java.sql.Statement;
40+
41+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
43+
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
44+
45+
@RunWith(IoTDBTestRunner.class)
46+
@Category({AIClusterIT.class})
47+
public class AINodeCallInferenceIT {
48+
49+
private static final String[] WRITE_SQL_IN_TREE =
50+
new String[] {
51+
"CREATE DATABASE root.AI",
52+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
53+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
54+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
55+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
56+
};
57+
58+
private static final String CALL_INFERENCE_SQL_TEMPLATE =
59+
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
60+
private static final int DEFAULT_INPUT_LENGTH = 256;
61+
private static final int DEFAULT_OUTPUT_LENGTH = 48;
62+
63+
@BeforeClass
64+
public static void setUp() throws Exception {
65+
// Init 1C1D1A cluster environment
66+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
67+
prepareData(WRITE_SQL_IN_TREE);
68+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
69+
Statement statement = connection.createStatement()) {
70+
for (int i = 0; i < 2880; i++) {
71+
statement.execute(
72+
String.format(
73+
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
74+
i, (float) i, (double) i, i, i));
75+
}
76+
}
77+
}
78+
79+
@AfterClass
80+
public static void tearDown() throws Exception {
81+
EnvFactory.getEnv().cleanClusterEnvironment();
82+
}
83+
84+
@Test
85+
public void callInferenceTest() throws SQLException {
86+
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
87+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
88+
Statement statement = connection.createStatement()) {
89+
callInferenceTest(statement, modelInfo);
90+
}
91+
}
92+
}
93+
94+
public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
95+
throws SQLException {
96+
// Invoke call inference for specified models, there should exist result.
97+
for (int i = 0; i < 4; i++) {
98+
String callInferenceSQL =
99+
String.format(
100+
CALL_INFERENCE_SQL_TEMPLATE,
101+
modelInfo.getModelId(),
102+
i,
103+
DEFAULT_INPUT_LENGTH,
104+
DEFAULT_OUTPUT_LENGTH);
105+
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
106+
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
107+
checkHeader(resultSetMetaData, "Time,output");
108+
int count = 0;
109+
while (resultSet.next()) {
110+
count++;
111+
}
112+
// Ensure the call inference return results
113+
Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
114+
}
115+
}
116+
}
117+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.ainode.utils.AINodeTestUtils;
23+
import org.apache.iotdb.it.env.EnvFactory;
24+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
25+
import org.apache.iotdb.itbase.category.AIClusterIT;
26+
import org.apache.iotdb.itbase.env.BaseEnv;
27+
28+
import org.junit.AfterClass;
29+
import org.junit.BeforeClass;
30+
import org.junit.Test;
31+
import org.junit.experimental.categories.Category;
32+
import org.junit.runner.RunWith;
33+
import org.slf4j.Logger;
34+
import org.slf4j.LoggerFactory;
35+
36+
import java.sql.Connection;
37+
import java.sql.SQLException;
38+
import java.sql.Statement;
39+
40+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
41+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
44+
45+
@RunWith(IoTDBTestRunner.class)
46+
@Category({AIClusterIT.class})
47+
public class AINodeConcurrentForecastIT {
48+
49+
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
50+
51+
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
52+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";
53+
54+
@BeforeClass
55+
public static void setUp() throws Exception {
56+
// Init 1C1D1A cluster environment
57+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
58+
prepareDataForTableModel();
59+
}
60+
61+
@AfterClass
62+
public static void tearDown() throws Exception {
63+
EnvFactory.getEnv().cleanClusterEnvironment();
64+
}
65+
66+
private static void prepareDataForTableModel() throws SQLException {
67+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
68+
Statement statement = connection.createStatement()) {
69+
statement.execute("CREATE DATABASE root");
70+
statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
71+
for (int i = 0; i < 2880; i++) {
72+
statement.execute(
73+
String.format(
74+
"INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
75+
}
76+
}
77+
}
78+
79+
@Test
80+
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
81+
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
82+
concurrentGPUForecastTest(modelInfo);
83+
}
84+
}
85+
86+
public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
87+
throws SQLException, InterruptedException {
88+
final int forecastLength = 512;
89+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
90+
Statement statement = connection.createStatement()) {
91+
// Single forecast request can be processed successfully
92+
final String forecastSQL =
93+
String.format(
94+
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
95+
final int threadCnt = 10;
96+
final int loop = 100;
97+
final String devices = "0,1";
98+
statement.execute(
99+
String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
100+
checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
101+
long startTime = System.currentTimeMillis();
102+
concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength);
103+
long endTime = System.currentTimeMillis();
104+
LOGGER.info(
105+
String.format(
106+
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
107+
modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime));
108+
statement.execute(
109+
String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices));
110+
checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)