@@ -20,14 +20,17 @@
package org.apache.iotdb.ainode.it;

import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.AIClusterIT;
import org.apache.iotdb.itbase.env.BaseEnv;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -36,21 +39,17 @@
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeConcurrentInferenceIT {

private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);

private static final Map<String, String> MODEL_ID_TO_TYPE_MAP =
ImmutableMap.of(
"timer_xl", "Timer-XL",
"sundial", "Timer-Sundial");

@BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
@@ -86,13 +85,12 @@ private static void prepareDataForTableModel() throws SQLException {
for (int i = 0; i < 2880; i++) {
statement.execute(
String.format(
"INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)",
i, Math.sin(i * Math.PI / 1440)));
"INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
}
}
}

@Test
// @Test
public void concurrentCPUCallInferenceTest() throws SQLException, InterruptedException {
concurrentCPUCallInferenceTest("timer_xl");
concurrentCPUCallInferenceTest("sundial");
@@ -105,21 +103,21 @@ private void concurrentCPUCallInferenceTest(String modelId)
final int threadCnt = 4;
final int loop = 10;
final int predictLength = 96;
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
statement.execute(String.format("LOAD MODEL %s TO DEVICES 'cpu'", modelId));
checkModelOnSpecifiedDevice(statement, modelId, "cpu");
concurrentInference(
statement,
String.format(
"CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
"CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)",
modelId, predictLength),
threadCnt,
loop,
predictLength);
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES 'cpu'", modelId));
}
}

@Test
// @Test
public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException {
concurrentGPUCallInferenceTest("timer_xl");
concurrentGPUCallInferenceTest("sundial");
@@ -133,17 +131,17 @@ private void concurrentGPUCallInferenceTest(String modelId)
final int loop = 100;
final int predictLength = 512;
final String devices = "0,1";
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
checkModelOnSpecifiedDevice(statement, modelId, devices);
concurrentInference(
statement,
String.format(
"CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
"CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)",
modelId, predictLength),
threadCnt,
loop,
predictLength);
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
}
}

@@ -159,8 +157,8 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
final int threadCnt = 4;
final int loop = 10;
final int predictLength = 96;
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
statement.execute(String.format("LOAD MODEL %s TO DEVICES 'cpu'", modelId));
checkModelOnSpecifiedDevice(statement, modelId, "cpu");
long startTime = System.currentTimeMillis();
concurrentInference(
statement,
Expand All @@ -175,7 +173,7 @@ private void concurrentCPUForecastTest(String modelId) throws SQLException, Inte
String.format(
"Model %s concurrent inference %d reqs (%d threads, %d loops) in CPU takes time: %dms",
modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES 'cpu'", modelId));
}
}

@@ -192,8 +190,8 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
final int loop = 100;
final int predictLength = 512;
final String devices = "0,1";
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
checkModelOnSpecifiedDevice(statement, modelId, devices);
long startTime = System.currentTimeMillis();
concurrentInference(
statement,
@@ -208,32 +206,35 @@ public void concurrentGPUForecastTest(String modelId) throws SQLException, Inter
String.format(
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
}
}

private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device)
private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
throws SQLException, InterruptedException {
for (int retry = 0; retry < 10; retry++) {
Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
for (int retry = 0; retry < 20; retry++) {
Set<String> foundDevices = new HashSet<>();
try (final ResultSet resultSet =
statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) {
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
while (resultSet.next()) {
String deviceId = resultSet.getString(1);
String loadedModelType = resultSet.getString(2);
int count = resultSet.getInt(3);
if (loadedModelType.equals(modelType) && targetDevices.contains(deviceId)) {
Assert.assertTrue(count > 1);
String deviceId = resultSet.getString("DeviceId");
String loadedModelId = resultSet.getString("ModelId");
int count = resultSet.getInt("Count(instances)");
LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
foundDevices.add(deviceId);
LOGGER.info("Model {} is loaded to device {}", modelId, device);
}
}
if (foundDevices.containsAll(targetDevices)) {
LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices);
return;
}
}
TimeUnit.SECONDS.sleep(3);
}
Assert.fail("Model " + modelType + " is not loaded on device " + device);
Assert.fail("Model " + modelId + " is not loaded on device " + device);
}
}
@@ -95,7 +95,7 @@ public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}

@Test
// @Test
public void callInferenceTestInTree() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
@@ -209,7 +209,7 @@ public void callInferenceTest(Statement statement) throws SQLException {
// }
}

@Test
// @Test
public void errorCallInferenceTestInTree() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
@@ -36,6 +36,8 @@
)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device


@@ -58,7 +60,7 @@ class InferenceRequestPool(mp.Process):
def __init__(
self,
pool_id: int,
model_id: str,
model_info: ModelInfo,
device: str,
config: PretrainedConfig,
request_queue: mp.Queue,
@@ -68,7 +70,7 @@ def __init__(
):
super().__init__()
self.pool_id = pool_id
self.model_id = model_id
self.model_info = model_info
self.config = config
self.pool_kwargs = pool_kwargs
self.ready_event = ready_event
@@ -121,7 +123,7 @@ def _step(self):

for requests in grouped_requests:
batch_inputs = self._batcher.batch_request(requests).to(self.device)
if self.model_id == "sundial":
if self.model_info.model_type == BuiltInModelType.SUNDIAL.value:
batch_output = self._model.generate(
batch_inputs,
max_new_tokens=requests[0].max_new_tokens,
@@ -135,8 +137,7 @@ def _step(self):
cur_batch_size = request.batch_size
cur_output = batch_output[offset : offset + cur_batch_size]
offset += cur_batch_size
# TODO Here we only considered the case where batchsize=1 in one request. If multi-variable adaptation is required in the future, modifications may be needed here, such as: `cur_output[0]` maybe not true in multi-variable scene
request.write_step_output(cur_output[0].mean(dim=0))
request.write_step_output(cur_output.mean(dim=1))

request.inference_pipeline.post_decode()
if request.is_finished():
@@ -153,7 +154,7 @@ def _step(self):
)
self._waiting_queue.put(request)

elif self.model_id == "timer_xl":
elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value:
batch_output = self._model.generate(
batch_inputs,
max_new_tokens=requests[0].max_new_tokens,
@@ -194,7 +195,9 @@ def run(self):
)
self._model_manager = ModelManager()
self._request_scheduler.device = self.device
self._model = self._model_manager.load_model(self.model_id, {}).to(self.device)
self._model = self._model_manager.load_model(self.model_info.model_id, {}).to(
self.device
)
self.ready_event.set()

activate_daemon = threading.Thread(
@@ -207,10 +210,13 @@ def run(self):
)
self._threads.append(execute_daemon)
execute_daemon.start()
self._logger.info(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} is activated."
)
for thread in self._threads:
thread.join()
self._logger.info(
f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_id} exited cleanly."
f"[Inference][Device-{self.device}][Pool-{self.pool_id}] InferenceRequestPool for model {self.model_info.model_id} exited cleanly."
)

def stop(self):
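
A behavioral change in this hunk is easy to miss: the Sundial branch now writes cur_output.mean(dim=1) instead of cur_output[0].mean(dim=0), keeping the whole batch and averaging only over the sampled trajectories. A minimal sketch of that shape handling, assuming generate() returns a [batch_size, num_samples, predict_length] tensor (the exact layout is an assumption, not shown in this diff):

import torch

# Hypothetical output of model.generate(): 4 series in the batch, 10 sampled
# trajectories per series, 96 predicted steps.
batch_output = torch.randn(4, 10, 96)

# Averaging over dim=1 collapses the samples into one point forecast per series,
# so every request in the batch keeps its own forecast.
point_forecast = batch_output.mean(dim=1)
assert point_forecast.shape == (4, 96)
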
21 changes: 15 additions & 6 deletions iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -40,6 +40,8 @@
ScaleActionType,
)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
from iotdb.ainode.core.util.atmoic_int import AtomicInt
@@ -48,6 +50,7 @@
from iotdb.ainode.core.util.thread_name import ThreadName

logger = Logger()
MODEL_MANAGER = ModelManager()


class PoolController:
@@ -169,7 +172,7 @@ def show_loaded_models(
for model_id, device_map in self._request_pool_map.items():
if device_id in device_map:
pool_group = device_map[device_id]
device_models[model_id] = pool_group.get_pool_count()
device_models[model_id] = pool_group.get_running_pool_count()
result[device_id] = device_models
return result

@@ -191,7 +194,7 @@ def _load_model_task(self, model_id: str, device_id_list: list[str]):
def _load_model_on_device_task(device_id: str):
if not self.has_request_pools(model_id, device_id):
actions = self._pool_scheduler.schedule_load_model_to_device(
model_id, device_id
MODEL_MANAGER.get_model_info(model_id), device_id
)
for action in actions:
if action.action == ScaleActionType.SCALE_UP:
@@ -218,7 +221,7 @@ def _unload_model_task(self, model_id: str, device_id_list: list[str]):
def _unload_model_on_device_task(device_id: str):
if self.has_request_pools(model_id, device_id):
actions = self._pool_scheduler.schedule_unload_model_from_device(
model_id, device_id
MODEL_MANAGER.get_model_info(model_id), device_id
)
for action in actions:
if action.action == ScaleActionType.SCALE_DOWN:
@@ -253,13 +256,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
def _expand_pool_on_device(*_):
result_queue = mp.Queue()
pool_id = self._new_pool_id.get_and_increment()
if model_id == "sundial":
model_info = MODEL_MANAGER.get_model_info(model_id)
model_type = model_info.model_type
if model_type == BuiltInModelType.SUNDIAL.value:
config = SundialConfig()
elif model_id == "timer_xl":
elif model_type == BuiltInModelType.TIMER_XL.value:
config = TimerConfig()
else:
raise InferenceModelInternalError(
f"Unsupported model type {model_type} for loading model {model_id}"
)
pool = InferenceRequestPool(
pool_id=pool_id,
model_id=model_id,
model_info=model_info,
device=device_id,
config=config,
request_queue=result_queue,
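
For context, pool expansion now branches on model_info.model_type rather than on the raw model id, so differently named registrations of the same built-in model should still pick the right config. A rough sketch of that dispatch, assuming ModelInfo exposes at least model_id and model_type; the enum values below are placeholders, and only the SUNDIAL and TIMER_XL members appear in this diff:

from dataclasses import dataclass
from enum import Enum

class BuiltInModelType(Enum):
    SUNDIAL = "sundial"      # placeholder value for illustration
    TIMER_XL = "timer_xl"    # placeholder value for illustration

@dataclass
class ModelInfo:
    model_id: str
    model_type: str

def pick_config(info: ModelInfo) -> str:
    # Mirrors _expand_pool_on_device: select a config per built-in type,
    # and fail loudly for anything unsupported instead of silently skipping.
    if info.model_type == BuiltInModelType.SUNDIAL.value:
        return "SundialConfig"
    if info.model_type == BuiltInModelType.TIMER_XL.value:
        return "TimerConfig"
    raise ValueError(
        f"Unsupported model type {info.model_type} for loading model {info.model_id}"
    )

print(pick_config(ModelInfo("my_sundial", BuiltInModelType.SUNDIAL.value)))  # SundialConfig
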
6 changes: 6 additions & 0 deletions iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py
@@ -70,6 +70,12 @@ def get_pool_ids(self) -> list[int]:
def get_pool_count(self) -> int:
return len(self.pool_group)

def get_running_pool_count(self) -> int:
count = 0
for _, state in self.pool_states.items():
count += 1 if state == PoolState.RUNNING else 0
return count

def dispatch_request(
self, req: InferenceRequest, infer_proxy: InferenceRequestProxy
):
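
The new get_running_pool_count() is what show_loaded_models() now reports, so pools that are still starting up (or shutting down) no longer count as loaded. A small illustration, assuming pool_states maps pool ids to a PoolState enum; only the RUNNING member is visible in this diff, the other members are placeholders:

from enum import Enum

class PoolState(Enum):
    INITIALIZING = 0   # placeholder
    RUNNING = 1
    STOPPING = 2       # placeholder

pool_states = {0: PoolState.RUNNING, 1: PoolState.INITIALIZING, 2: PoolState.RUNNING}

# Equivalent to get_running_pool_count(): only RUNNING pools are counted.
running = sum(1 for state in pool_states.values() if state == PoolState.RUNNING)
assert running == 2
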