Skip to content

Commit a104fdb

Browse files
jtmerCRZbulabula
authored andcommitted
[AINode] Add a batcher for inference (#16411)
(cherry picked from commit 7734331)
1 parent ec5821a commit a104fdb

File tree

2 files changed

+57
-17
lines changed

2 files changed

+57
-17
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -81,61 +81,103 @@ private static void prepareDataForTableModel() throws SQLException {
8181

8282
@Test
8383
public void concurrentCPUCallInferenceTest() throws SQLException, InterruptedException {
84+
concurrentCPUCallInferenceTest("timer_xl");
85+
concurrentCPUCallInferenceTest("sundial");
86+
}
87+
88+
private void concurrentCPUCallInferenceTest(String modelId)
89+
throws SQLException, InterruptedException {
8490
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
8591
Statement statement = connection.createStatement()) {
86-
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\"");
87-
concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 4, 10);
92+
final int threadCnt = 4;
93+
final int loop = 10;
94+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
95+
concurrentInference(
96+
statement,
97+
String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
98+
threadCnt,
99+
loop);
100+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
88101
}
89102
}
90103

91104
@Test
92105
public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException {
106+
concurrentGPUCallInferenceTest("timer_xl");
107+
concurrentGPUCallInferenceTest("sundial");
108+
}
109+
110+
private void concurrentGPUCallInferenceTest(String modelId)
111+
throws SQLException, InterruptedException {
93112
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
94113
Statement statement = connection.createStatement()) {
95-
statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\"");
96-
concurrentInference(statement, "CALL INFERENCE(sundial, \"SELECT s FROM root.AI\")", 10, 100);
114+
final int threadCnt = 4;
115+
final int loop = 10;
116+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
117+
concurrentInference(
118+
statement,
119+
String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
120+
threadCnt,
121+
loop);
122+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
97123
}
98124
}
99125

100126
@Test
101127
public void concurrentCPUForecastTest() throws SQLException, InterruptedException {
102-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
128+
concurrentCPUForecastTest("timer_xl");
129+
concurrentCPUForecastTest("sundial");
130+
}
131+
132+
private void concurrentCPUForecastTest(String modelId) throws SQLException, InterruptedException {
133+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
103134
Statement statement = connection.createStatement()) {
104135
final int threadCnt = 4;
105136
final int loop = 10;
106-
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu\"");
137+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
107138
long startTime = System.currentTimeMillis();
108139
concurrentInference(
109140
statement,
110-
"SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
141+
String.format(
142+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
143+
modelId),
111144
threadCnt,
112145
loop);
113146
long endTime = System.currentTimeMillis();
114147
LOGGER.info(
115148
String.format(
116-
"Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms",
117-
threadCnt * loop, threadCnt, loop, endTime - startTime));
149+
"Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms",
150+
modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
151+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
118152
}
119153
}
120154

121155
@Test
122156
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
123-
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
157+
concurrentGPUForecastTest("timer_xl");
158+
concurrentGPUForecastTest("sundial");
159+
}
160+
161+
public void concurrentGPUForecastTest(String modelId) throws SQLException, InterruptedException {
162+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
124163
Statement statement = connection.createStatement()) {
125164
final int threadCnt = 10;
126165
final int loop = 100;
127-
statement.execute("LOAD MODEL sundial TO DEVICES \"0,1\"");
166+
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
128167
long startTime = System.currentTimeMillis();
129168
concurrentInference(
130169
statement,
131-
"SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
170+
String.format(
171+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
172+
modelId),
132173
threadCnt,
133174
loop);
134175
long endTime = System.currentTimeMillis();
135176
LOGGER.info(
136177
String.format(
137-
"Timer-Sundial concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
138-
threadCnt * loop, threadCnt, loop, endTime - startTime));
178+
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
179+
modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
180+
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
139181
}
140182
}
141183
}

iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919

2020
import torch.multiprocessing as mp
2121

22-
from iotdb.ainode.core.exception import (
23-
InferenceModelInternalError,
24-
)
22+
from iotdb.ainode.core.exception import InferenceModelInternalError
2523
from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher
2624
from iotdb.ainode.core.inference.inference_request import (
2725
InferenceRequest,

0 commit comments

Comments
 (0)