2626# include < unistd.h>
2727#endif
2828#include < cstring>
29+ #include < variant>
30+
31+ #define RPC_QUEUE
32+ #ifdef RPC_QUEUE
33+ #include < condition_variable>
34+ #include < functional>
35+ #include < queue>
36+ #include < thread>
37+ #endif
2938
3039#ifdef _WIN32
3140typedef SOCKET sockfd_t ;
@@ -160,6 +169,42 @@ struct rpc_msg_get_device_memory_rsp {
160169};
161170#pragma pack(pop)
162171
172+ #ifdef RPC_QUEUE
173+ struct rpc_task_t {
174+ rpc_cmd cmd;
175+ typedef std::variant<rpc_msg_alloc_buffer_req, rpc_msg_get_alloc_size_req,
176+ rpc_msg_buffer_get_base_req, rpc_msg_free_buffer_req,
177+ rpc_msg_buffer_clear_req, std::vector<uint8_t >,
178+ rpc_msg_get_tensor_req, rpc_msg_copy_tensor_req,
179+ rpc_msg_init_tensor_req> req_t ;
180+ req_t req;
181+ std::variant<rpc_msg_alloc_buffer_rsp, rpc_msg_get_alloc_size_rsp,
182+ rpc_msg_get_alignment_rsp, rpc_msg_get_max_size_rsp,
183+ rpc_msg_buffer_get_base_rsp, std::vector<uint8_t >,
184+ rpc_msg_copy_tensor_rsp, rpc_msg_graph_compute_rsp,
185+ rpc_msg_get_device_memory_rsp, bool > rsp;
186+ sockfd_t sockfd;
187+ std::mutex response_mutex;
188+ rpc_task_t (rpc_task_t && t) : cmd(t.cmd), req(t.req), rsp(t.rsp), sockfd(t.sockfd) {}
189+ rpc_task_t (rpc_cmd cmd, req_t req, sockfd_t sockfd) : cmd(cmd), req(req), sockfd(sockfd) {}
190+ };
191+
192+ struct rpc_queue_t {
193+ std::queue<rpc_task_t > tasks;
194+ std::mutex mutex;
195+ std::condition_variable cond;
196+ volatile bool running;
197+ };
198+
199+ struct rpc_worker_context {
200+ std::shared_ptr<rpc_queue_t > queue;
201+ ggml_backend_t backend;
202+ size_t free_mem;
203+ size_t total_mem;
204+ };
205+ void * process_queue (rpc_worker_context* ctx);
206+ #endif
207+
163208// RPC data structures
164209
165210static ggml_guid_t ggml_backend_rpc_guid () {
@@ -1125,7 +1170,8 @@ rpc_server::~rpc_server() {
11251170 ggml_backend_buffer_free (buffer);
11261171 }
11271172}
1128-
1173+ #ifndef RPC_QUEUE
1174+ // if you change, then do synchronize change with such name function
11291175static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
11301176 rpc_server server (backend);
11311177 while (true ) {
@@ -1312,7 +1358,133 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
13121358 }
13131359 }
13141360}
1361+ #else
13151362
1363+ // if you change, then do synchronize change with such name function
1364+ static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, std::shared_ptr<rpc_queue_t > _queue) {
1365+ auto queue = _queue.get ();
1366+ rpc_server server (backend);
1367+ while (true ) {
1368+ rpc_cmd cmd;
1369+ if (!recv_data (sockfd, &cmd, 1 )) {
1370+ break ;
1371+ }
1372+ if (cmd >= RPC_CMD_COUNT) {
1373+ // fail fast if the command is invalid
1374+ fprintf (stderr, " Unknown command: %d\n " , cmd);
1375+ break ;
1376+ }
1377+ rpc_task_t ::req_t req;
1378+ switch (cmd) {
1379+ case RPC_CMD_ALLOC_BUFFER: {
1380+ rpc_msg_alloc_buffer_req request;
1381+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1382+ return ;
1383+ }
1384+ req = request;
1385+ break ;
1386+ }
1387+ case RPC_CMD_GET_ALLOC_SIZE: {
1388+ rpc_msg_get_alloc_size_req request;
1389+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1390+ return ;
1391+ }
1392+ req = request;
1393+ break ;
1394+ }
1395+ case RPC_CMD_GET_ALIGNMENT: {
1396+ if (!recv_msg (sockfd, nullptr , 0 )) {
1397+ return ;
1398+ }
1399+ break ;
1400+ }
1401+ case RPC_CMD_GET_MAX_SIZE: {
1402+ if (!recv_msg (sockfd, nullptr , 0 )) {
1403+ return ;
1404+ }
1405+ break ;
1406+ }
1407+ case RPC_CMD_BUFFER_GET_BASE: {
1408+ rpc_msg_buffer_get_base_req request;
1409+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1410+ return ;
1411+ }
1412+ req = request;
1413+ break ;
1414+ }
1415+ case RPC_CMD_FREE_BUFFER: {
1416+ rpc_msg_free_buffer_req request;
1417+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1418+ return ;
1419+ }
1420+ req = request;
1421+ break ;
1422+ }
1423+ case RPC_CMD_BUFFER_CLEAR: {
1424+ rpc_msg_buffer_clear_req request;
1425+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1426+ return ;
1427+ }
1428+ req = request;
1429+ break ;
1430+ }
1431+ case RPC_CMD_SET_TENSOR: {
1432+ std::vector<uint8_t > input;
1433+ if (!recv_msg (sockfd, input)) {
1434+ return ;
1435+ }
1436+ req = input;
1437+ break ;
1438+ }
1439+ case RPC_CMD_INIT_TENSOR: {
1440+ rpc_msg_init_tensor_req request;
1441+ if (!recv_msg (sockfd, &request,sizeof (request))) {
1442+ return ;
1443+ }
1444+ req = request;
1445+ break ;
1446+ }
1447+ case RPC_CMD_GET_TENSOR: {
1448+ rpc_msg_get_tensor_req request;
1449+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1450+ return ;
1451+ }
1452+ req = request;
1453+ break ;
1454+ }
1455+ case RPC_CMD_COPY_TENSOR: {
1456+ rpc_msg_copy_tensor_req request;
1457+ if (!recv_msg (sockfd, &request, sizeof (request))) {
1458+ return ;
1459+ }
1460+ req = request;
1461+ break ;
1462+ }
1463+ case RPC_CMD_GRAPH_COMPUTE: {
1464+ std::vector<uint8_t > input;
1465+ if (!recv_msg (sockfd, input)) {
1466+ return ;
1467+ }
1468+ req = input;
1469+ break ;
1470+ }
1471+ case RPC_CMD_GET_DEVICE_MEMORY: {
1472+ if (!recv_msg (sockfd, nullptr , 0 )) {
1473+ return ;
1474+ }
1475+ break ;
1476+ }
1477+ default : {
1478+ fprintf (stderr, " Unknown command: %d\n " , cmd);
1479+ return ;
1480+ }
1481+ }
1482+ std::lock_guard<std::mutex> lock (queue->mutex );
1483+ queue->tasks .push (rpc_task_t {cmd, req, sockfd});
1484+ queue->cond .notify_one ();
1485+ }
1486+ }
1487+ #endif
13161488void ggml_backend_rpc_start_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
13171489 std::string host;
13181490 int port;
@@ -1334,6 +1506,12 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13341506 fprintf (stderr, " Failed to create server socket\n " );
13351507 return ;
13361508 }
1509+
1510+ #ifdef RPC_QUEUE
1511+ std::shared_ptr<rpc_queue_t > queue;
1512+ queue->running = true ;
1513+ std::thread worker_thread (process_queue, new rpc_worker_context{queue, backend, free_mem, total_mem});
1514+ #endif
13371515 while (true ) {
13381516 auto client_socket = socket_accept (server_socket->fd );
13391517 if (client_socket == nullptr ) {
@@ -1342,10 +1520,18 @@ void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint
13421520 }
13431521 printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
13441522 fflush (stdout);
1523+ #ifndef RPC_QUEUE
13451524 rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1525+ #else
1526+ rpc_serve_client (backend, client_socket->fd , queue);
1527+ #endif
13461528 printf (" Client connection closed\n " );
13471529 fflush (stdout);
13481530 }
1531+ queue->running = false ;
1532+ #ifdef RPC_QUEUE
1533+ worker_thread.join ();
1534+ #endif
13491535#ifdef _WIN32
13501536 WSACleanup ();
13511537#endif
@@ -1494,7 +1680,134 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
14941680
14951681 return &ggml_backend_rpc_reg;
14961682}
1683+ #ifdef RPC_QUEUE
1684+ bool send_response (const rpc_task_t & task);
1685+ bool send_response (const rpc_task_t & task) {
1686+ size_t response_size = 0 ;
1687+ const void * response_data = nullptr ;
14971688
1689+ switch (task.cmd ) {
1690+ case RPC_CMD_ALLOC_BUFFER:
1691+ response_data = &std::get<rpc_msg_alloc_buffer_rsp>(task.rsp );
1692+ response_size = sizeof (rpc_msg_alloc_buffer_rsp);
1693+ break ;
1694+ case RPC_CMD_GET_ALIGNMENT:
1695+ response_data = &std::get<rpc_msg_get_alignment_rsp>(task.rsp );
1696+ response_size = sizeof (rpc_msg_get_alignment_rsp);
1697+ break ;
1698+ case RPC_CMD_GET_MAX_SIZE:
1699+ response_data = &std::get<rpc_msg_get_max_size_rsp>(task.rsp );
1700+ response_size = sizeof (rpc_msg_get_max_size_rsp);
1701+ break ;
1702+ case RPC_CMD_BUFFER_GET_BASE:
1703+ response_data = &std::get<rpc_msg_buffer_get_base_rsp>(task.rsp );
1704+ response_size = sizeof (rpc_msg_buffer_get_base_rsp);
1705+ break ;
1706+ case RPC_CMD_COPY_TENSOR:
1707+ response_data = &std::get<rpc_msg_copy_tensor_rsp>(task.rsp );
1708+ response_size = sizeof (rpc_msg_copy_tensor_rsp);
1709+ break ;
1710+ case RPC_CMD_INIT_TENSOR:
1711+ case RPC_CMD_SET_TENSOR:
1712+ response_data = &std::get<bool >(task.rsp );
1713+ response_size = 1 ;
1714+ break ;
1715+ case RPC_CMD_GET_TENSOR:
1716+ response_data = std::get<std::vector<uint8_t >>(task.rsp ).data ();
1717+ response_size = std::get<std::vector<uint8_t >>(task.rsp ).size ();
1718+ break ;
1719+ case RPC_CMD_GRAPH_COMPUTE:
1720+ response_data = &std::get<rpc_msg_graph_compute_rsp>(task.rsp ).result ;
1721+ response_size = sizeof (rpc_msg_graph_compute_rsp::result);
1722+ break ;
1723+ case RPC_CMD_GET_DEVICE_MEMORY:
1724+ response_data = &std::get<rpc_msg_get_device_memory_rsp>(task.rsp );
1725+ response_size = sizeof (rpc_msg_get_device_memory_rsp);
1726+ break ;
1727+ case RPC_CMD_BUFFER_CLEAR:
1728+ case RPC_CMD_FREE_BUFFER:
1729+ // No response data for this command
1730+ response_size = 0 ;
1731+ break ;
1732+ default :
1733+ response_size = 0 ;
1734+ }
1735+ return send_msg (task.sockfd , response_data, response_size);
1736+ }
1737+
1738+ void * process_queue (rpc_worker_context* ctx) {
1739+ rpc_queue_t * queue = ctx->queue .get ();
1740+ rpc_server server (ctx->backend );
1741+ while (queue->running ) {
1742+ std::unique_lock<std::mutex> lock (queue->mutex );
1743+ // ReSharper disable once CppDFAConstantConditions
1744+ queue->cond .wait (lock, [queue] { return !queue->tasks .empty () || !queue->running ; });
1745+
1746+ if (queue->tasks .empty ()) {
1747+ break ;
1748+ }
1749+ rpc_task_t &task = ctx->queue ->tasks .front ();
1750+ queue->tasks .pop ();
1751+ lock.unlock ();
1752+ switch (task.cmd ) {
1753+ case RPC_CMD_ALLOC_BUFFER:
1754+ server.alloc_buffer (std::get<rpc_msg_alloc_buffer_req>(task.req ), std::get<rpc_msg_alloc_buffer_rsp>(task.rsp ));
1755+ std::get<bool >(task.rsp ) = true ;
1756+ break ;
1757+ case RPC_CMD_GET_ALIGNMENT:
1758+ server.get_alignment (std::get<rpc_msg_get_alignment_rsp>(task.rsp ));
1759+ std::get<bool >(task.rsp ) = true ;
1760+ break ;
1761+ case RPC_CMD_GET_MAX_SIZE:
1762+ server.get_max_size (std::get<rpc_msg_get_max_size_rsp>(task.rsp ));
1763+ std::get<bool >(task.rsp ) = true ;
1764+ break ;
1765+ case RPC_CMD_BUFFER_GET_BASE:
1766+ std::get<bool >(task.rsp ) = server.buffer_get_base (std::get<rpc_msg_buffer_get_base_req>(task.req ), std::get<rpc_msg_buffer_get_base_rsp>(task.rsp ));
1767+ break ;
1768+ case RPC_CMD_FREE_BUFFER:
1769+ std::get<bool >(task.rsp ) = server.free_buffer (std::get<rpc_msg_free_buffer_req>(task.req ));
1770+ break ;
1771+ case RPC_CMD_BUFFER_CLEAR:
1772+ std::get<bool >(task.rsp )= server.buffer_clear (std::get<rpc_msg_buffer_clear_req>(task.req ));
1773+ break ;
1774+ case RPC_CMD_SET_TENSOR:
1775+ std::get<bool >(task.rsp ) = server.set_tensor (std::get<std::vector<uint8_t >>(task.req ));
1776+ break ;
1777+ case RPC_CMD_GET_TENSOR:
1778+ std::get<bool >(task.rsp ) = server.get_tensor (std::get<rpc_msg_get_tensor_req>(task.req ), std::get<std::vector<uint8_t >>(task.rsp ));
1779+ break ;
1780+ case RPC_CMD_COPY_TENSOR:
1781+ std::get<bool >(task.rsp ) = server.copy_tensor (std::get<rpc_msg_copy_tensor_req>(task.req ), std::get<rpc_msg_copy_tensor_rsp>(task.rsp ));
1782+ break ;
1783+ case RPC_CMD_GRAPH_COMPUTE:
1784+ std::get<bool >(task.rsp ) = server.graph_compute (std::get<std::vector<uint8_t >>(task.req ), std::get<rpc_msg_graph_compute_rsp>(task.rsp ));
1785+ break ;
1786+ case RPC_CMD_GET_DEVICE_MEMORY:
1787+ std::get<rpc_msg_get_device_memory_rsp>(task.rsp ).free_mem = ctx->free_mem ;
1788+ std::get<rpc_msg_get_device_memory_rsp>(task.rsp ).total_mem = ctx->total_mem ;
1789+ std::get<bool >(task.rsp ) = true ;
1790+ break ;
1791+ case RPC_CMD_INIT_TENSOR:
1792+ std::get<bool >(task.rsp ) = server.init_tensor (std::get<rpc_msg_init_tensor_req>(task.req ));
1793+ break ;
1794+ case RPC_CMD_GET_ALLOC_SIZE:
1795+ std::get<bool >(task.rsp ) = server.get_alloc_size (std::get<rpc_msg_get_alloc_size_req>(task.req ), std::get<rpc_msg_get_alloc_size_rsp>(task.rsp ));
1796+ break ;
1797+ default :
1798+ std::get<bool >(task.rsp ) = false ;
1799+ break ;
1800+ }
1801+
1802+ std::lock_guard<std::mutex> response_lock (task.response_mutex );
1803+ if (!send_response (task)) {
1804+ std::get<bool >(task.rsp ) = true ;
1805+ }
1806+ }
1807+ return nullptr ;
1808+ }
1809+
1810+ #endif
14981811ggml_backend_dev_t ggml_backend_rpc_add_device (const char * endpoint) {
14991812 static std::unordered_map<std::string, ggml_backend_dev_t > dev_map;
15001813
@@ -1515,9 +1828,7 @@ ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
15151828 /* .reg = */ ggml_backend_rpc_reg (),
15161829 /* .context = */ ctx,
15171830 };
1518-
15191831 dev_map[endpoint] = dev;
1520-
15211832 return dev;
15221833}
15231834
0 commit comments