Skip to content

Commit b4dde12

Browse files
authored
[AINode][Bug fix] Concurrent inference (apache#16518)
* trigger CI * bug fix 4 show loaded models
1 parent e03560f commit b4dde12

File tree

3 files changed

+76
-17
lines changed

3 files changed

+76
-17
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java

Lines changed: 70 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,35 @@
2222
import org.apache.iotdb.it.env.EnvFactory;
2323
import org.apache.iotdb.itbase.env.BaseEnv;
2424

25+
import com.google.common.collect.ImmutableMap;
26+
import com.google.common.collect.ImmutableSet;
2527
import org.junit.AfterClass;
28+
import org.junit.Assert;
2629
import org.junit.BeforeClass;
2730
import org.junit.Test;
2831
import org.slf4j.Logger;
2932
import org.slf4j.LoggerFactory;
3033

3134
import java.sql.Connection;
35+
import java.sql.ResultSet;
3236
import java.sql.SQLException;
3337
import java.sql.Statement;
38+
import java.util.HashSet;
39+
import java.util.Map;
40+
import java.util.Set;
41+
import java.util.concurrent.TimeUnit;
3442

3543
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
3644

3745
public class AINodeConcurrentInferenceIT {
3846

3947
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
4048

49+
private static final Map<String, String> MODEL_ID_TO_TYPE_MAP =
50+
ImmutableMap.of(
51+
"timer_xl", "Timer-XL",
52+
"sundial", "Timer-Sundial");
53+
4154
@BeforeClass
4255
public static void setUp() throws Exception {
4356
// Init 1C1D1A cluster environment
@@ -91,12 +104,17 @@ private void concurrentCPUCallInferenceTest(String modelId)
91104
Statement statement = connection.createStatement()) {
92105
final int threadCnt = 4;
93106
final int loop = 10;
107+
final int predictLength = 96;
94108
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
109+
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
95110
concurrentInference(
96111
statement,
97-
String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
112+
String.format(
113+
"CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
114+
modelId, predictLength),
98115
threadCnt,
99-
loop);
116+
loop,
117+
predictLength);
100118
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
101119
}
102120
}
@@ -111,14 +129,20 @@ private void concurrentGPUCallInferenceTest(String modelId)
111129
throws SQLException, InterruptedException {
112130
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
113131
Statement statement = connection.createStatement()) {
114-
final int threadCnt = 4;
115-
final int loop = 10;
116-
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
132+
final int threadCnt = 10;
133+
final int loop = 100;
134+
final int predictLength = 512;
135+
final String devices = "0,1";
136+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
137+
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
117138
concurrentInference(
118139
statement,
119-
String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
140+
String.format(
141+
"CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
142+
modelId, predictLength),
120143
threadCnt,
121-
loop);
144+
loop,
145+
predictLength);
122146
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
123147
}
124148
}
@@ -134,15 +158,18 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
134158
Statement statement = connection.createStatement()) {
135159
final int threadCnt = 4;
136160
final int loop = 10;
161+
final int predictLength = 96;
137162
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
163+
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
138164
long startTime = System.currentTimeMillis();
139165
concurrentInference(
140166
statement,
141167
String.format(
142-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
143-
modelId),
168+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
169+
modelId, predictLength),
144170
threadCnt,
145-
loop);
171+
loop,
172+
predictLength);
146173
long endTime = System.currentTimeMillis();
147174
LOGGER.info(
148175
String.format(
@@ -163,15 +190,19 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
163190
Statement statement = connection.createStatement()) {
164191
final int threadCnt = 10;
165192
final int loop = 100;
166-
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
193+
final int predictLength = 512;
194+
final String devices = "0,1";
195+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
196+
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
167197
long startTime = System.currentTimeMillis();
168198
concurrentInference(
169199
statement,
170200
String.format(
171-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
172-
modelId),
201+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
202+
modelId, predictLength),
173203
threadCnt,
174-
loop);
204+
loop,
205+
predictLength);
175206
long endTime = System.currentTimeMillis();
176207
LOGGER.info(
177208
String.format(
@@ -180,4 +211,29 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
180211
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
181212
}
182213
}
214+
215+
private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device)
216+
throws SQLException, InterruptedException {
217+
for (int retry = 0; retry < 10; retry++) {
218+
Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
219+
Set<String> foundDevices = new HashSet<>();
220+
try (final ResultSet resultSet =
221+
statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) {
222+
while (resultSet.next()) {
223+
String deviceId = resultSet.getString(1);
224+
String loadedModelType = resultSet.getString(2);
225+
int count = resultSet.getInt(3);
226+
if (loadedModelType.equals(modelType) && targetDevices.contains(deviceId)) {
227+
Assert.assertTrue(count > 1);
228+
foundDevices.add(deviceId);
229+
}
230+
}
231+
if (foundDevices.containsAll(targetDevices)) {
232+
return;
233+
}
234+
}
235+
TimeUnit.SECONDS.sleep(3);
236+
}
237+
Assert.fail("Model " + modelType + " is not loaded on device " + device);
238+
}
183239
}

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ public static void errorTest(Statement statement, String sql, String errorMessag
6060
}
6161
}
6262

63-
public static void concurrentInference(Statement statement, String sql, int threadCnt, int loop)
63+
public static void concurrentInference(
64+
Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
6465
throws InterruptedException {
6566
Thread[] threads = new Thread[threadCnt];
6667
for (int i = 0; i < threadCnt; i++) {
@@ -70,9 +71,11 @@ public static void concurrentInference(Statement statement, String sql, int thre
7071
try {
7172
for (int j = 0; j < loop; j++) {
7273
try (ResultSet resultSet = statement.executeQuery(sql)) {
74+
int outputCnt = 0;
7375
while (resultSet.next()) {
74-
// do nothing
76+
outputCnt++;
7577
}
78+
assertEquals(expectedOutputLength, outputCnt);
7679
} catch (SQLException e) {
7780
fail(e.getMessage());
7881
}

iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def show_loaded_models(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp
217217
status=get_status(TSStatusCode.SUCCESS_STATUS),
218218
deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
219219
req.deviceIdList
220-
if req.deviceIdList is not None
220+
if len(req.deviceIdList) > 0
221221
else get_available_devices()
222222
),
223223
)

0 commit comments

Comments
 (0)