Skip to content

Commit fbbc675

Browse files
CRZbulabulayunbow30944jtmerAlchuang22-dev
authored
[AINode][To rc/2.0.6] cp AINode codes (#16910)
* [AINode] Refactor code base * [AINode] Implement concurrent inference framework (#16311) (cherry picked from commit 7b9ec7e) * [AINode] Fix bugs for SHOW LOADED MODELS (#16410) (cherry picked from commit 40b2b33) * [AINode] Add a batcher for inference (#16411) (cherry picked from commit 7734331) * [AINode][Bug fix] Concurrent inference (#16518) * trigger CI * bug fix 4 show loaded models (cherry picked from commit b4dde12) * [AINode] Concurrent inference bug fix (#16595) (cherry picked from commit 46a0c6a) * [AINode] Adjust the maximum inference input length (#16640) (cherry picked from commit 2c9064f) * [AINode] Fix bug of sundial and forecast udf (#16768) (cherry picked from commit 2b47be7) * [AINode] Package AINode via PyInstaller (#16707) (cherry picked from commit 49c625b) * [AINode] Enable AINode start as background (-d) (#16762) (cherry picked from commit 1ebb951) * [AINode] Update AINodeClient for DataNode to borrow (#16647) (cherry picked from commit d49d7dd) * [AINode] Fix bug that AINode cannot compile in Windows (#16767) (cherry picked from commit cd443ba) * [AINode] Delete poetry.lock for easier maintain different operating systems (#16793) (cherry picked from commit 50f92e4) * [AINode] Fix cp errors --------- Co-authored-by: Leo <[email protected]> Co-authored-by: jtmer <[email protected]> Co-authored-by: Zeyu Zhang <[email protected]>
1 parent 460f3e9 commit fbbc675

File tree

192 files changed

+7655
-8291
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

192 files changed

+7655
-8291
lines changed

.github/workflows/cluster-it-1c1d1a.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ jobs:
4141

4242
steps:
4343
- uses: actions/checkout@v4
44-
- name: Build AINode
45-
shell: bash
46-
run: mvn clean package -DskipTests -P with-ainode
4744
- name: IT Test
4845
shell: bash
4946
run: |
@@ -59,5 +56,5 @@ jobs:
5956
uses: actions/upload-artifact@v4
6057
with:
6158
name: cluster-log-ainode-${{ matrix.os }}
62-
path: integration-test/target/ainode-logs
59+
path: integration-test/target/*-logs
6360
retention-days: 30

integration-test/src/assembly/mpp-test.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
</fileSet>
6464
<fileSet>
6565
<outputDirectory>lib</outputDirectory>
66-
<directory>${project.basedir}/../iotdb-core/ainode/dist/</directory>
66+
<directory>${project.basedir}/../iotdb-core/ainode/dist/ainode/</directory>
6767
<fileMode>0755</fileMode>
6868
</fileSet>
6969
</fileSets>

integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
import org.apache.commons.io.file.PathUtils;
2626
import org.slf4j.Logger;
2727

28-
import java.io.BufferedWriter;
2928
import java.io.File;
30-
import java.io.FileWriter;
29+
import java.io.FileInputStream;
30+
import java.io.FileOutputStream;
3131
import java.io.IOException;
3232
import java.nio.file.Files;
3333
import java.nio.file.LinkOption;
@@ -37,6 +37,7 @@
3737
import java.nio.file.StandardCopyOption;
3838
import java.util.ArrayList;
3939
import java.util.List;
40+
import java.util.Properties;
4041
import java.util.stream.Stream;
4142

4243
import static org.apache.iotdb.it.env.cluster.ClusterConstant.AI_NODE_NAME;
@@ -59,18 +60,22 @@ public class AINodeWrapper extends AbstractNodeWrapper {
5960
public static final String CONFIG_PATH = "conf";
6061
public static final String SCRIPT_PATH = "sbin";
6162
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights";
62-
public static final String CACHE_BUILT_IN_MODEL_PATH = "/tmp/data/ainode/models/weights";
63+
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
6364

6465
private void replaceAttribute(String[] keys, String[] values, String filePath) {
65-
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath, true))) {
66-
for (int i = 0; i < keys.length; i++) {
67-
String line = keys[i] + "=" + values[i];
68-
writer.newLine();
69-
writer.write(line);
70-
}
66+
Properties props = new Properties();
67+
try (FileInputStream in = new FileInputStream(filePath)) {
68+
props.load(in);
69+
} catch (IOException e) {
70+
logger.warn("Failed to load existing AINode properties from {}, because: ", filePath, e);
71+
}
72+
for (int i = 0; i < keys.length; i++) {
73+
props.setProperty(keys[i], values[i]);
74+
}
75+
try (FileOutputStream out = new FileOutputStream(filePath)) {
76+
props.store(out, "Updated by AINode integration-test env");
7177
} catch (IOException e) {
72-
logger.error(
73-
"Failed to set attribute for AINode in file: {} because {}", filePath, e.getMessage());
78+
logger.error("Failed to save properties to {}, because:", filePath, e);
7479
}
7580
}
7681

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24+
import org.apache.iotdb.itbase.category.AIClusterIT;
25+
import org.apache.iotdb.itbase.env.BaseEnv;
26+
27+
import com.google.common.collect.ImmutableSet;
28+
import org.junit.AfterClass;
29+
import org.junit.Assert;
30+
import org.junit.BeforeClass;
31+
import org.junit.Test;
32+
import org.junit.experimental.categories.Category;
33+
import org.junit.runner.RunWith;
34+
import org.slf4j.Logger;
35+
import org.slf4j.LoggerFactory;
36+
37+
import java.sql.Connection;
38+
import java.sql.ResultSet;
39+
import java.sql.SQLException;
40+
import java.sql.Statement;
41+
import java.util.HashSet;
42+
import java.util.Set;
43+
import java.util.concurrent.TimeUnit;
44+
45+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
46+
47+
@RunWith(IoTDBTestRunner.class)
48+
@Category({AIClusterIT.class})
49+
public class AINodeConcurrentInferenceIT {
50+
51+
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
52+
53+
@BeforeClass
54+
public static void setUp() throws Exception {
55+
// Init 1C1D1A cluster environment
56+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
57+
prepareDataForTreeModel();
58+
prepareDataForTableModel();
59+
}
60+
61+
@AfterClass
62+
public static void tearDown() throws Exception {
63+
EnvFactory.getEnv().cleanClusterEnvironment();
64+
}
65+
66+
private static void prepareDataForTreeModel() throws SQLException {
67+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
68+
Statement statement = connection.createStatement()) {
69+
statement.execute("CREATE DATABASE root.AI");
70+
statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE");
71+
for (int i = 0; i < 2880; i++) {
72+
statement.execute(
73+
String.format(
74+
"INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)",
75+
i, Math.sin(i * Math.PI / 1440)));
76+
}
77+
}
78+
}
79+
80+
private static void prepareDataForTableModel() throws SQLException {
81+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
82+
Statement statement = connection.createStatement()) {
83+
statement.execute("CREATE DATABASE root");
84+
statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
85+
for (int i = 0; i < 2880; i++) {
86+
statement.execute(
87+
String.format(
88+
"INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
89+
}
90+
}
91+
}
92+
93+
// @Test
94+
public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException {
95+
concurrentGPUCallInferenceTest("timer_xl");
96+
concurrentGPUCallInferenceTest("sundial");
97+
}
98+
99+
private void concurrentGPUCallInferenceTest(String modelId)
100+
throws SQLException, InterruptedException {
101+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
102+
Statement statement = connection.createStatement()) {
103+
final int threadCnt = 10;
104+
final int loop = 100;
105+
final int predictLength = 512;
106+
final String devices = "0,1";
107+
statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
108+
checkModelOnSpecifiedDevice(statement, modelId, devices);
109+
concurrentInference(
110+
statement,
111+
String.format(
112+
"CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)",
113+
modelId, predictLength),
114+
threadCnt,
115+
loop,
116+
predictLength);
117+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
118+
}
119+
}
120+
121+
@Test
122+
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
123+
concurrentGPUForecastTest("timer_xl");
124+
concurrentGPUForecastTest("sundial");
125+
}
126+
127+
public void concurrentGPUForecastTest(String modelId) throws SQLException, InterruptedException {
128+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
129+
Statement statement = connection.createStatement()) {
130+
final int threadCnt = 10;
131+
final int loop = 100;
132+
final int predictLength = 512;
133+
final String devices = "0,1";
134+
statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
135+
checkModelOnSpecifiedDevice(statement, modelId, devices);
136+
long startTime = System.currentTimeMillis();
137+
concurrentInference(
138+
statement,
139+
String.format(
140+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
141+
modelId, predictLength),
142+
threadCnt,
143+
loop,
144+
predictLength);
145+
long endTime = System.currentTimeMillis();
146+
LOGGER.info(
147+
String.format(
148+
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
149+
modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
150+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
151+
}
152+
}
153+
154+
private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
155+
throws SQLException, InterruptedException {
156+
Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
157+
LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
158+
for (int retry = 0; retry < 200; retry++) {
159+
Set<String> foundDevices = new HashSet<>();
160+
try (final ResultSet resultSet =
161+
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
162+
while (resultSet.next()) {
163+
String deviceId = resultSet.getString("DeviceId");
164+
String loadedModelId = resultSet.getString("ModelId");
165+
int count = resultSet.getInt("Count(instances)");
166+
LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
167+
if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
168+
foundDevices.add(deviceId);
169+
LOGGER.info("Model {} is loaded to device {}", modelId, device);
170+
}
171+
}
172+
if (foundDevices.containsAll(targetDevices)) {
173+
LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices);
174+
return;
175+
}
176+
}
177+
TimeUnit.SECONDS.sleep(3);
178+
}
179+
Assert.fail("Model " + modelId + " is not loaded on device " + device);
180+
}
181+
}

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public static void tearDown() throws Exception {
9595
EnvFactory.getEnv().cleanClusterEnvironment();
9696
}
9797

98-
@Test
98+
// @Test
9999
public void callInferenceTestInTree() throws SQLException {
100100
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
101101
Statement statement = connection.createStatement()) {
@@ -209,7 +209,7 @@ public void callInferenceTest(Statement statement) throws SQLException {
209209
// }
210210
}
211211

212-
@Test
212+
// @Test
213213
public void errorCallInferenceTestInTree() throws SQLException {
214214
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
215215
Statement statement = connection.createStatement()) {

0 commit comments

Comments
 (0)