Skip to content

Commit f71aabf

Browse files
authored
[AINode] Integrate device manager framework (#16998)
1 parent 4ea03a5 commit f71aabf

File tree

31 files changed

+658
-253
lines changed

31 files changed

+658
-253
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,15 @@ private static void prepareDataForTableModel() throws SQLException {
8383
}
8484

8585
@Test
86-
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
86+
public void concurrentForecastTest() throws SQLException, InterruptedException {
8787
for (AINodeTestUtils.FakeModelInfo modelInfo : MODEL_LIST) {
88-
concurrentGPUForecastTest(modelInfo);
88+
concurrentGPUForecastTest(modelInfo, "0,1");
89+
// TODO: Enable cpu test after optimize memory consumption
90+
// concurrentGPUForecastTest(modelInfo, "cpu");
8991
}
9092
}
9193

92-
public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
94+
public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo, String devices)
9395
throws SQLException, InterruptedException {
9496
final int forecastLength = 512;
9597
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
@@ -100,7 +102,6 @@ public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
100102
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
101103
final int threadCnt = 10;
102104
final int loop = 100;
103-
final String devices = "0,1";
104105
statement.execute(
105106
String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
106107
checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.ainode.it;
21+
22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24+
import org.apache.iotdb.itbase.category.AIClusterIT;
25+
import org.apache.iotdb.itbase.env.BaseEnv;
26+
27+
import org.junit.AfterClass;
28+
import org.junit.Assert;
29+
import org.junit.BeforeClass;
30+
import org.junit.Test;
31+
import org.junit.experimental.categories.Category;
32+
import org.junit.runner.RunWith;
33+
34+
import java.sql.Connection;
35+
import java.sql.ResultSet;
36+
import java.sql.ResultSetMetaData;
37+
import java.sql.SQLException;
38+
import java.sql.Statement;
39+
import java.util.Arrays;
40+
import java.util.LinkedList;
41+
import java.util.List;
42+
43+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
44+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable;
45+
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree;
46+
47+
@RunWith(IoTDBTestRunner.class)
48+
@Category({AIClusterIT.class})
49+
public class AINodeDeviceManageIT {
50+
51+
@BeforeClass
52+
public static void setUp() throws Exception {
53+
// Init 1C1D1A cluster environment
54+
EnvFactory.getEnv().initClusterEnvironment(1, 1);
55+
prepareDataInTree();
56+
prepareDataInTable();
57+
}
58+
59+
@AfterClass
60+
public static void tearDown() throws Exception {
61+
EnvFactory.getEnv().cleanClusterEnvironment();
62+
}
63+
64+
@Test
65+
public void showAIDeviceTestInTree() throws SQLException {
66+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
67+
Statement statement = connection.createStatement()) {
68+
showAIDevicesTest(statement);
69+
}
70+
}
71+
72+
@Test
73+
public void showAIDeviceTestInTable() throws SQLException {
74+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
75+
Statement statement = connection.createStatement()) {
76+
showAIDevicesTest(statement);
77+
}
78+
}
79+
80+
private void showAIDevicesTest(Statement statement) throws SQLException {
81+
final String showSql = "SHOW AI_DEVICES";
82+
final List<String> expectedDeviceIdList = new LinkedList<>(Arrays.asList("0", "1", "cpu"));
83+
final List<String> expectedDeviceTypeList =
84+
new LinkedList<>(Arrays.asList("cuda", "cuda", "cpu"));
85+
try (ResultSet resultSet = statement.executeQuery(showSql)) {
86+
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
87+
checkHeader(resultSetMetaData, "DeviceId,DeviceType");
88+
while (resultSet.next()) {
89+
String deviceId = resultSet.getString(1);
90+
String deviceType = resultSet.getString(2);
91+
Assert.assertEquals(expectedDeviceIdList.remove(0), deviceId);
92+
Assert.assertEquals(expectedDeviceTypeList.remove(0), deviceType);
93+
}
94+
}
95+
}
96+
}

iotdb-core/ainode/iotdb/ainode/core/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"timer": 856 * 1024**2, # 856 MiB
5757
} # the memory usage of each model in bytes
5858

59-
AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference
59+
AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference
6060
AINODE_INFERENCE_EXTRA_MEMORY_RATIO = (
6161
1.2 # the overhead ratio for inference, used to estimate the pool size
6262
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
from enum import Enum
20+
from typing import ContextManager, Optional, Protocol
21+
22+
import torch
23+
24+
25+
class BackendType(Enum):
26+
"""
27+
Different types of supported computation backends.
28+
AINode will automatically select the available backend according to the order defined here.
29+
"""
30+
31+
CUDA = "cuda"
32+
CPU = "cpu"
33+
34+
35+
class BackendAdapter(Protocol):
36+
type: BackendType
37+
38+
# device basics
39+
def is_available(self) -> bool: ...
40+
def device_count(self) -> int: ...
41+
def make_device(self, index: Optional[int]) -> torch.device: ...
42+
def set_device(self, index: int) -> None: ...
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
import torch
20+
21+
from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
22+
23+
24+
class CPUBackend(BackendAdapter):
25+
type = BackendType.CPU
26+
27+
def is_available(self) -> bool:
28+
return True
29+
30+
def device_count(self) -> int:
31+
return 1
32+
33+
def make_device(self, index: int | None) -> torch.device:
34+
return torch.device("cpu")
35+
36+
def set_device(self, index: int) -> None:
37+
return None
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
import torch
20+
21+
from iotdb.ainode.core.device.backend.base import BackendAdapter, BackendType
22+
23+
24+
class CUDABackend(BackendAdapter):
25+
type = BackendType.CUDA
26+
27+
def is_available(self) -> bool:
28+
return torch.cuda.is_available()
29+
30+
def device_count(self) -> int:
31+
return torch.cuda.device_count()
32+
33+
def make_device(self, index: int | None) -> torch.device:
34+
if index is None:
35+
raise ValueError("CUDA backend requires a valid device index")
36+
return torch.device(f"cuda:{index}")
37+
38+
def set_device(self, index: int) -> None:
39+
torch.cuda.set_device(index)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
from dataclasses import dataclass
19+
from typing import Optional, Union
20+
21+
import torch
22+
23+
DeviceLike = Union[torch.device, str, int]
24+
25+
26+
@dataclass(frozen=True)
27+
class DeviceSpec:
28+
type: str
29+
index: Optional[int]
30+
31+
32+
def parse_device_like(x: DeviceLike) -> DeviceSpec:
33+
if isinstance(x, int):
34+
return DeviceSpec("index", x)
35+
36+
if isinstance(x, str):
37+
try:
38+
return DeviceSpec("index", int(x))
39+
except ValueError:
40+
s = x.strip().lower()
41+
if ":" in s:
42+
t, idx = s.split(":", 1)
43+
return DeviceSpec(t, int(idx))
44+
return DeviceSpec(s, None)
45+
46+
if isinstance(x, torch.device):
47+
return DeviceSpec(x.type, x.index)
48+
49+
raise TypeError(f"Unsupported device: {x!r}")
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
#
18+
19+
import os
20+
from dataclasses import dataclass
21+
22+
23+
@dataclass(frozen=True)
24+
class DistEnv:
25+
rank: int
26+
local_rank: int
27+
world_size: int
28+
29+
30+
def read_dist_env() -> DistEnv:
31+
# torchrun:
32+
rank = int(os.environ.get("RANK", "0"))
33+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
34+
35+
# torchrun provides LOCAL_RANK; slurm often provides SLURM_LOCALID
36+
local_rank = os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", "0"))
37+
local_rank = int(local_rank)
38+
39+
return DistEnv(rank=rank, local_rank=local_rank, world_size=world_size)

0 commit comments

Comments
 (0)