Skip to content

Commit 12cd8da

Browse files
committed
bug fix & append it
1 parent b1e299b commit 12cd8da

File tree

3 files changed

+105
-9
lines changed

3 files changed

+105
-9
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24+
import org.apache.iotdb.itbase.category.AIClusterIT;
25+
import org.apache.iotdb.itbase.env.BaseEnv;
26+
27+
import org.junit.AfterClass;
28+
import org.junit.Assert;
29+
import org.junit.BeforeClass;
30+
import org.junit.Test;
31+
import org.junit.experimental.categories.Category;
32+
import org.junit.runner.RunWith;
33+
34+
import java.sql.Connection;
35+
import java.sql.ResultSet;
36+
import java.sql.ResultSetMetaData;
37+
import java.sql.SQLException;
38+
import java.sql.Statement;
39+
import java.util.Arrays;
40+
import java.util.List;
41+
42+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
44+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
45+
46+
@RunWith(IoTDBTestRunner.class)
47+
@Category({AIClusterIT.class})
48+
public class AINodeDeviceManageIT {
49+
50+
@BeforeClass
51+
public static void setUp() throws Exception {
52+
// Init 1C1D1A cluster environment
53+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
54+
prepareDataInTree();
55+
prepareDataInTable();
56+
}
57+
58+
@AfterClass
59+
public static void tearDown() throws Exception {
60+
EnvFactory.getEnv().cleanClusterEnvironment();
61+
}
62+
63+
@Test
64+
public void showAIDeviceTestInTree() throws SQLException {
65+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
66+
Statement statement = connection.createStatement()) {
67+
showAIDevicesTest(statement);
68+
}
69+
}
70+
71+
@Test
72+
public void showAIDeviceTestInTable() throws SQLException {
73+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
74+
Statement statement = connection.createStatement()) {
75+
showAIDevicesTest(statement);
76+
}
77+
}
78+
79+
private void showAIDevicesTest(Statement statement) throws SQLException {
80+
final String showSql = "SHOW AI_DEVICES";
81+
final List<String> expectedDeviceIdList = Arrays.asList("0", "1", "cpu");
82+
final List<String> expectedDeviceTypeList = Arrays.asList("cuda", "cuda", "cpu");
83+
try (ResultSet resultSet = statement.executeQuery(showSql)) {
84+
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
85+
checkHeader(resultSetMetaData, "DeviceId,DeviceType");
86+
while (resultSet.next()) {
87+
String deviceId = resultSet.getString(1);
88+
String deviceType = resultSet.getString(2);
89+
Assert.assertEquals(expectedDeviceIdList.remove(0), deviceId);
90+
Assert.assertEquals(expectedDeviceTypeList.remove(0), deviceType);
91+
}
92+
}
93+
}
94+
}

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ def __init__(
7777
self.ready_event = ready_event
7878
self.device = device
7979

80-
self._backend = DeviceManager()
81-
8280
self._threads = []
8381
self._waiting_queue = request_queue # Requests that are waiting to be processed
8482
self._running_queue = mp.Queue() # Requests that are currently being processed
@@ -89,8 +87,8 @@ def __init__(
8987
self._batcher = BasicBatcher()
9088
self._stop_event = mp.Event()
9189

90+
self._backend = None
9291
self._inference_pipeline = None
93-
9492
self._logger = None
9593

9694
# Fix inference seed
@@ -186,6 +184,7 @@ def run(self):
186184
self._logger = Logger(
187185
INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
188186
)
187+
self._backend = DeviceManager()
189188
self._request_scheduler.device = self.device
190189
self._inference_pipeline = load_pipeline(self.model_info, self.device)
191190
self.ready_event.set()

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowAIDevicesTask.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,15 @@ public static void buildTsBlock(
5454
.map(ColumnHeader::getColumnType)
5555
.collect(Collectors.toList());
5656
TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
57-
for (Map.Entry<String, String> deviceEntry : resp.getDeviceIdMap().entrySet()) {
58-
builder.getTimeColumnBuilder().writeLong(0L);
59-
builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceEntry.getKey()));
60-
builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(deviceEntry.getValue()));
61-
builder.declarePosition();
62-
}
57+
resp.getDeviceIdMap().entrySet().stream()
58+
.sorted(Map.Entry.comparingByKey())
59+
.forEach(
60+
deviceEntry -> {
61+
builder.getTimeColumnBuilder().writeLong(0L);
62+
builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(deviceEntry.getKey()));
63+
builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(deviceEntry.getValue()));
64+
builder.declarePosition();
65+
});
6366
DatasetHeader datasetHeader = DatasetHeaderFactory.getShowAIDevicesHeader();
6467
future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS, builder.build(), datasetHeader));
6568
}

0 commit comments

Comments
 (0)