Skip to content

Commit eec3029

Browse files
author
lexasub
committed
rpc: init(continue) queue for msg's - now is builded - server 2-threads created
1 parent 53debe6 commit eec3029

File tree

1 file changed

+314
-3
lines changed

1 file changed

+314
-3
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 314 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,15 @@
2626
# include <unistd.h>
2727
#endif
2828
#include <cstring>
29+
#include <variant>
30+
31+
#define RPC_QUEUE
32+
#ifdef RPC_QUEUE
33+
#include <condition_variable>
34+
#include <functional>
35+
#include <queue>
36+
#include <thread>
37+
#endif
2938

3039
#ifdef _WIN32
3140
typedef SOCKET sockfd_t;
@@ -160,6 +169,42 @@ struct rpc_msg_get_device_memory_rsp {
160169
};
161170
#pragma pack(pop)
162171

172+
#ifdef RPC_QUEUE
173+
struct rpc_task_t {
174+
rpc_cmd cmd;
175+
typedef std::variant<rpc_msg_alloc_buffer_req, rpc_msg_get_alloc_size_req,
176+
rpc_msg_buffer_get_base_req, rpc_msg_free_buffer_req,
177+
rpc_msg_buffer_clear_req, std::vector<uint8_t>,
178+
rpc_msg_get_tensor_req, rpc_msg_copy_tensor_req,
179+
rpc_msg_init_tensor_req> req_t;
180+
req_t req;
181+
std::variant<rpc_msg_alloc_buffer_rsp, rpc_msg_get_alloc_size_rsp,
182+
rpc_msg_get_alignment_rsp, rpc_msg_get_max_size_rsp,
183+
rpc_msg_buffer_get_base_rsp, std::vector<uint8_t>,
184+
rpc_msg_copy_tensor_rsp, rpc_msg_graph_compute_rsp,
185+
rpc_msg_get_device_memory_rsp, bool> rsp;
186+
sockfd_t sockfd;
187+
std::mutex response_mutex;
188+
rpc_task_t(rpc_task_t&& t) : cmd(t.cmd), req(t.req), rsp(t.rsp), sockfd(t.sockfd) {}
189+
rpc_task_t(rpc_cmd cmd, req_t req, sockfd_t sockfd) : cmd(cmd), req(req), sockfd(sockfd) {}
190+
};
191+
192+
struct rpc_queue_t {
193+
std::queue<rpc_task_t> tasks;
194+
std::mutex mutex;
195+
std::condition_variable cond;
196+
volatile bool running;
197+
};
198+
199+
struct rpc_worker_context {
200+
std::shared_ptr<rpc_queue_t> queue;
201+
ggml_backend_t backend;
202+
size_t free_mem;
203+
size_t total_mem;
204+
};
205+
void* process_queue(rpc_worker_context* ctx);
206+
#endif
207+
163208
// RPC data structures
164209

165210
static ggml_guid_t ggml_backend_rpc_guid() {
@@ -1125,7 +1170,8 @@ rpc_server::~rpc_server() {
11251170
ggml_backend_buffer_free(buffer);
11261171
}
11271172
}
1128-
1173+
#ifndef RPC_QUEUE
1174+
//if you change, then do synchronize change with such name function
11291175
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
11301176
rpc_server server(backend);
11311177
while (true) {
@@ -1312,7 +1358,133 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13121358
}
13131359
}
13141360
}
1361+
#else
13151362

1363+
//if you change, then do synchronize change with such name function
1364+
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, std::shared_ptr<rpc_queue_t> _queue) {
1365+
auto queue = _queue.get();
1366+
rpc_server server(backend);
1367+
while (true) {
1368+
rpc_cmd cmd;
1369+
if (!recv_data(sockfd, &cmd, 1)) {
1370+
break;
1371+
}
1372+
if (cmd >= RPC_CMD_COUNT) {
1373+
// fail fast if the command is invalid
1374+
fprintf(stderr, "Unknown command: %d\n", cmd);
1375+
break;
1376+
}
1377+
rpc_task_t::req_t req;
1378+
switch (cmd) {
1379+
case RPC_CMD_ALLOC_BUFFER: {
1380+
rpc_msg_alloc_buffer_req request;
1381+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1382+
return;
1383+
}
1384+
req = request;
1385+
break;
1386+
}
1387+
case RPC_CMD_GET_ALLOC_SIZE: {
1388+
rpc_msg_get_alloc_size_req request;
1389+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1390+
return;
1391+
}
1392+
req = request;
1393+
break;
1394+
}
1395+
case RPC_CMD_GET_ALIGNMENT: {
1396+
if (!recv_msg(sockfd, nullptr, 0)) {
1397+
return;
1398+
}
1399+
break;
1400+
}
1401+
case RPC_CMD_GET_MAX_SIZE: {
1402+
if (!recv_msg(sockfd, nullptr, 0)) {
1403+
return;
1404+
}
1405+
break;
1406+
}
1407+
case RPC_CMD_BUFFER_GET_BASE: {
1408+
rpc_msg_buffer_get_base_req request;
1409+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1410+
return;
1411+
}
1412+
req = request;
1413+
break;
1414+
}
1415+
case RPC_CMD_FREE_BUFFER: {
1416+
rpc_msg_free_buffer_req request;
1417+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1418+
return;
1419+
}
1420+
req = request;
1421+
break;
1422+
}
1423+
case RPC_CMD_BUFFER_CLEAR: {
1424+
rpc_msg_buffer_clear_req request;
1425+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1426+
return;
1427+
}
1428+
req = request;
1429+
break;
1430+
}
1431+
case RPC_CMD_SET_TENSOR: {
1432+
std::vector<uint8_t> input;
1433+
if (!recv_msg(sockfd, input)) {
1434+
return;
1435+
}
1436+
req = input;
1437+
break;
1438+
}
1439+
case RPC_CMD_INIT_TENSOR: {
1440+
rpc_msg_init_tensor_req request;
1441+
if (!recv_msg(sockfd, &request,sizeof(request))) {
1442+
return;
1443+
}
1444+
req = request;
1445+
break;
1446+
}
1447+
case RPC_CMD_GET_TENSOR: {
1448+
rpc_msg_get_tensor_req request;
1449+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1450+
return;
1451+
}
1452+
req = request;
1453+
break;
1454+
}
1455+
case RPC_CMD_COPY_TENSOR: {
1456+
rpc_msg_copy_tensor_req request;
1457+
if (!recv_msg(sockfd, &request, sizeof(request))) {
1458+
return;
1459+
}
1460+
req = request;
1461+
break;
1462+
}
1463+
case RPC_CMD_GRAPH_COMPUTE: {
1464+
std::vector<uint8_t> input;
1465+
if (!recv_msg(sockfd, input)) {
1466+
return;
1467+
}
1468+
req = input;
1469+
break;
1470+
}
1471+
case RPC_CMD_GET_DEVICE_MEMORY: {
1472+
if (!recv_msg(sockfd, nullptr, 0)) {
1473+
return;
1474+
}
1475+
break;
1476+
}
1477+
default: {
1478+
fprintf(stderr, "Unknown command: %d\n", cmd);
1479+
return;
1480+
}
1481+
}
1482+
std::lock_guard<std::mutex> lock(queue->mutex);
1483+
queue->tasks.push(rpc_task_t{cmd, req, sockfd});
1484+
queue->cond.notify_one();
1485+
}
1486+
}
1487+
#endif
13161488
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
13171489
std::string host;
13181490
int port;
@@ -1334,6 +1506,12 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13341506
fprintf(stderr, "Failed to create server socket\n");
13351507
return;
13361508
}
1509+
1510+
#ifdef RPC_QUEUE
1511+
std::shared_ptr<rpc_queue_t> queue;
1512+
queue->running = true;
1513+
std::thread worker_thread(process_queue, new rpc_worker_context{queue, backend, free_mem, total_mem});
1514+
#endif
13371515
while (true) {
13381516
auto client_socket = socket_accept(server_socket->fd);
13391517
if (client_socket == nullptr) {
@@ -1342,10 +1520,18 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13421520
}
13431521
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
13441522
fflush(stdout);
1523+
#ifndef RPC_QUEUE
13451524
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1525+
#else
1526+
rpc_serve_client(backend, client_socket->fd, queue);
1527+
#endif
13461528
printf("Client connection closed\n");
13471529
fflush(stdout);
13481530
}
1531+
queue->running = false;
1532+
#ifdef RPC_QUEUE
1533+
worker_thread.join();
1534+
#endif
13491535
#ifdef _WIN32
13501536
WSACleanup();
13511537
#endif
@@ -1494,7 +1680,134 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
14941680

14951681
return &ggml_backend_rpc_reg;
14961682
}
1683+
#ifdef RPC_QUEUE
1684+
bool send_response(const rpc_task_t& task);
1685+
bool send_response(const rpc_task_t& task) {
1686+
size_t response_size = 0;
1687+
const void* response_data = nullptr;
14971688

1689+
switch (task.cmd) {
1690+
case RPC_CMD_ALLOC_BUFFER:
1691+
response_data = &std::get<rpc_msg_alloc_buffer_rsp>(task.rsp);
1692+
response_size = sizeof(rpc_msg_alloc_buffer_rsp);
1693+
break;
1694+
case RPC_CMD_GET_ALIGNMENT:
1695+
response_data = &std::get<rpc_msg_get_alignment_rsp>(task.rsp);
1696+
response_size = sizeof(rpc_msg_get_alignment_rsp);
1697+
break;
1698+
case RPC_CMD_GET_MAX_SIZE:
1699+
response_data = &std::get<rpc_msg_get_max_size_rsp>(task.rsp);
1700+
response_size = sizeof(rpc_msg_get_max_size_rsp);
1701+
break;
1702+
case RPC_CMD_BUFFER_GET_BASE:
1703+
response_data = &std::get<rpc_msg_buffer_get_base_rsp>(task.rsp);
1704+
response_size = sizeof(rpc_msg_buffer_get_base_rsp);
1705+
break;
1706+
case RPC_CMD_COPY_TENSOR:
1707+
response_data = &std::get<rpc_msg_copy_tensor_rsp>(task.rsp);
1708+
response_size = sizeof(rpc_msg_copy_tensor_rsp);
1709+
break;
1710+
case RPC_CMD_INIT_TENSOR:
1711+
case RPC_CMD_SET_TENSOR:
1712+
response_data = &std::get<bool>(task.rsp);
1713+
response_size = 1;
1714+
break;
1715+
case RPC_CMD_GET_TENSOR:
1716+
response_data = std::get<std::vector<uint8_t>>(task.rsp).data();
1717+
response_size = std::get<std::vector<uint8_t>>(task.rsp).size();
1718+
break;
1719+
case RPC_CMD_GRAPH_COMPUTE:
1720+
response_data = &std::get<rpc_msg_graph_compute_rsp>(task.rsp).result;
1721+
response_size = sizeof(rpc_msg_graph_compute_rsp::result);
1722+
break;
1723+
case RPC_CMD_GET_DEVICE_MEMORY:
1724+
response_data = &std::get<rpc_msg_get_device_memory_rsp>(task.rsp);
1725+
response_size = sizeof(rpc_msg_get_device_memory_rsp);
1726+
break;
1727+
case RPC_CMD_BUFFER_CLEAR:
1728+
case RPC_CMD_FREE_BUFFER:
1729+
// No response data for this command
1730+
response_size = 0;
1731+
break;
1732+
default:
1733+
response_size = 0;
1734+
}
1735+
return send_msg(task.sockfd, response_data, response_size);
1736+
}
1737+
1738+
void* process_queue(rpc_worker_context* ctx) {
1739+
rpc_queue_t* queue = ctx->queue.get();
1740+
rpc_server server(ctx->backend);
1741+
while (queue->running) {
1742+
std::unique_lock<std::mutex> lock(queue->mutex);
1743+
// ReSharper disable once CppDFAConstantConditions
1744+
queue->cond.wait(lock, [queue] { return !queue->tasks.empty() || !queue->running; });
1745+
1746+
if (queue->tasks.empty()) {
1747+
break;
1748+
}
1749+
rpc_task_t &task = ctx->queue->tasks.front();
1750+
queue->tasks.pop();
1751+
lock.unlock();
1752+
switch (task.cmd) {
1753+
case RPC_CMD_ALLOC_BUFFER:
1754+
server.alloc_buffer(std::get<rpc_msg_alloc_buffer_req>(task.req), std::get<rpc_msg_alloc_buffer_rsp>(task.rsp));
1755+
std::get<bool>(task.rsp) = true;
1756+
break;
1757+
case RPC_CMD_GET_ALIGNMENT:
1758+
server.get_alignment(std::get<rpc_msg_get_alignment_rsp>(task.rsp));
1759+
std::get<bool>(task.rsp) = true;
1760+
break;
1761+
case RPC_CMD_GET_MAX_SIZE:
1762+
server.get_max_size(std::get<rpc_msg_get_max_size_rsp>(task.rsp));
1763+
std::get<bool>(task.rsp) = true;
1764+
break;
1765+
case RPC_CMD_BUFFER_GET_BASE:
1766+
std::get<bool>(task.rsp) = server.buffer_get_base(std::get<rpc_msg_buffer_get_base_req>(task.req), std::get<rpc_msg_buffer_get_base_rsp>(task.rsp));
1767+
break;
1768+
case RPC_CMD_FREE_BUFFER:
1769+
std::get<bool>(task.rsp) = server.free_buffer(std::get<rpc_msg_free_buffer_req>(task.req));
1770+
break;
1771+
case RPC_CMD_BUFFER_CLEAR:
1772+
std::get<bool>(task.rsp)= server.buffer_clear(std::get<rpc_msg_buffer_clear_req>(task.req));
1773+
break;
1774+
case RPC_CMD_SET_TENSOR:
1775+
std::get<bool>(task.rsp) = server.set_tensor(std::get<std::vector<uint8_t>>(task.req));
1776+
break;
1777+
case RPC_CMD_GET_TENSOR:
1778+
std::get<bool>(task.rsp) = server.get_tensor(std::get<rpc_msg_get_tensor_req>(task.req), std::get<std::vector<uint8_t>>(task.rsp));
1779+
break;
1780+
case RPC_CMD_COPY_TENSOR:
1781+
std::get<bool>(task.rsp) = server.copy_tensor(std::get<rpc_msg_copy_tensor_req>(task.req), std::get<rpc_msg_copy_tensor_rsp>(task.rsp));
1782+
break;
1783+
case RPC_CMD_GRAPH_COMPUTE:
1784+
std::get<bool>(task.rsp) = server.graph_compute(std::get<std::vector<uint8_t>>(task.req), std::get<rpc_msg_graph_compute_rsp>(task.rsp));
1785+
break;
1786+
case RPC_CMD_GET_DEVICE_MEMORY:
1787+
std::get<rpc_msg_get_device_memory_rsp>(task.rsp).free_mem = ctx->free_mem;
1788+
std::get<rpc_msg_get_device_memory_rsp>(task.rsp).total_mem = ctx->total_mem;
1789+
std::get<bool>(task.rsp) = true;
1790+
break;
1791+
case RPC_CMD_INIT_TENSOR:
1792+
std::get<bool>(task.rsp) = server.init_tensor(std::get<rpc_msg_init_tensor_req>(task.req));
1793+
break;
1794+
case RPC_CMD_GET_ALLOC_SIZE:
1795+
std::get<bool>(task.rsp) = server.get_alloc_size(std::get<rpc_msg_get_alloc_size_req>(task.req), std::get<rpc_msg_get_alloc_size_rsp>(task.rsp));
1796+
break;
1797+
default:
1798+
std::get<bool>(task.rsp) = false;
1799+
break;
1800+
}
1801+
1802+
std::lock_guard<std::mutex> response_lock(task.response_mutex);
1803+
if (!send_response(task)) {
1804+
std::get<bool>(task.rsp) = true;
1805+
}
1806+
}
1807+
return nullptr;
1808+
}
1809+
1810+
#endif
14981811
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
14991812
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
15001813

@@ -1515,9 +1828,7 @@ ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
15151828
/* .reg = */ ggml_backend_rpc_reg(),
15161829
/* .context = */ ctx,
15171830
};
1518-
15191831
dev_map[endpoint] = dev;
1520-
15211832
return dev;
15221833
}
15231834

0 commit comments

Comments
 (0)