#include "./ggml-rpc-queue-server.h"
#ifdef RPC_QUEUE_SERVER
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <mutex>
#include <utility>
#include <variant>
#include <vector>

#include "./rpc_cmd.h"
#include "./rpc_msg.h"
#include "./ggml-rpc-server.h"
#include "ggml-backend.h"
// Socket handle type used by the transport helpers declared below.
// NOTE(review): this fixes the handle to `int`; confirm it matches the
// definition used by the translation unit that implements these helpers.
using sockfd_t = int;

// Transport helpers; only declared here, their definitions are not in this file.
// Each reports success or failure through its bool return value.
bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size);
bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input);
bool recv_data(sockfd_t sockfd, void * data, size_t size);
bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size);
15+ struct rpc_server_task_t {
16+ rpc_cmd cmd;
17+ typedef std::variant<rpc_msg_alloc_buffer_req, rpc_msg_get_alloc_size_req,
18+ rpc_msg_buffer_get_base_req, rpc_msg_free_buffer_req,
19+ rpc_msg_buffer_clear_req, std::vector<uint8_t >,
20+ rpc_msg_get_tensor_req, rpc_msg_copy_tensor_req,
21+ rpc_msg_init_tensor_req> req_t ;
22+ req_t req;
23+ std::variant<rpc_msg_alloc_buffer_rsp, rpc_msg_get_alloc_size_rsp,
24+ rpc_msg_get_alignment_rsp, rpc_msg_get_max_size_rsp,
25+ rpc_msg_buffer_get_base_rsp, std::vector<uint8_t >,
26+ rpc_msg_copy_tensor_rsp, rpc_msg_graph_compute_rsp,
27+ rpc_msg_get_device_memory_rsp, bool > rsp;
28+ sockfd_t sockfd;
29+ std::mutex response_mutex;
30+ rpc_server_task_t (rpc_server_task_t && t) : cmd(t.cmd), req(t.req), rsp(t.rsp), sockfd(t.sockfd) {}
31+ rpc_server_task_t (rpc_cmd cmd, req_t req, sockfd_t sockfd) : cmd(cmd), req(req), sockfd(sockfd) {}
32+ };
33+
34+ struct rpc_server_worker_context {
35+ std::shared_ptr<rpc_queue_t <rpc_server_task_t >> queue;
36+ ggml_backend_t backend;
37+ size_t free_mem;
38+ size_t total_mem;
39+ };
40+
// Worker loop that takes tasks from ctx->queue, executes them on a
// rpc_server built from ctx->backend, and sends the responses.
// Defined further down in this file.
void process_server_queue(rpc_server_worker_context * ctx);
42+
43+ bool send_response (const rpc_server_task_t & task);
44+ bool send_response (const rpc_server_task_t & task) {
45+ size_t response_size = 0 ;
46+ const void * response_data = nullptr ;
47+
48+ switch (task.cmd ) {
49+ case rpc_cmd::RPC_CMD_ALLOC_BUFFER:
50+ response_data = &std::get<rpc_msg_alloc_buffer_rsp>(task.rsp );
51+ response_size = sizeof (rpc_msg_alloc_buffer_rsp);
52+ break ;
53+ case RPC_CMD_GET_ALIGNMENT:
54+ response_data = &std::get<rpc_msg_get_alignment_rsp>(task.rsp );
55+ response_size = sizeof (rpc_msg_get_alignment_rsp);
56+ break ;
57+ case RPC_CMD_GET_MAX_SIZE:
58+ response_data = &std::get<rpc_msg_get_max_size_rsp>(task.rsp );
59+ response_size = sizeof (rpc_msg_get_max_size_rsp);
60+ break ;
61+ case RPC_CMD_BUFFER_GET_BASE:
62+ response_data = &std::get<rpc_msg_buffer_get_base_rsp>(task.rsp );
63+ response_size = sizeof (rpc_msg_buffer_get_base_rsp);
64+ break ;
65+ case RPC_CMD_COPY_TENSOR:
66+ response_data = &std::get<rpc_msg_copy_tensor_rsp>(task.rsp );
67+ response_size = sizeof (rpc_msg_copy_tensor_rsp);
68+ break ;
69+ case RPC_CMD_INIT_TENSOR:
70+ case RPC_CMD_SET_TENSOR:
71+ response_data = &std::get<bool >(task.rsp );
72+ response_size = 1 ;
73+ break ;
74+ case RPC_CMD_GET_TENSOR:
75+ response_data = std::get<std::vector<uint8_t >>(task.rsp ).data ();
76+ response_size = std::get<std::vector<uint8_t >>(task.rsp ).size ();
77+ break ;
78+ case RPC_CMD_GRAPH_COMPUTE:
79+ response_data = &std::get<rpc_msg_graph_compute_rsp>(task.rsp ).result ;
80+ response_size = sizeof (rpc_msg_graph_compute_rsp::result);
81+ break ;
82+ case RPC_CMD_GET_DEVICE_MEMORY:
83+ response_data = &std::get<rpc_msg_get_device_memory_rsp>(task.rsp );
84+ response_size = sizeof (rpc_msg_get_device_memory_rsp);
85+ break ;
86+ case RPC_CMD_BUFFER_CLEAR:
87+ case RPC_CMD_FREE_BUFFER:
88+ // No response data for this command
89+ response_size = 0 ;
90+ break ;
91+ default :
92+ response_size = 0 ;
93+ }
94+ return send_msg (task.sockfd , response_data, response_size);
95+ }
96+
97+ void process_server_queue (rpc_server_worker_context * ctx) {
98+ rpc_queue_t <rpc_server_task_t >* queue = ctx->queue .get ();
99+ rpc_server server (ctx->backend );
100+ while (queue->running ) {
101+ std::unique_lock<std::mutex> lock (queue->mutex );
102+ // ReSharper disable once CppDFAConstantConditions
103+ queue->cond .wait (lock, [queue] { return !queue->tasks .empty () || !queue->running ; });
104+
105+ if (queue->tasks .empty ()) {
106+ break ;
107+ }
108+ rpc_server_task_t &task = ctx->queue ->tasks .front ();
109+ queue->tasks .pop ();
110+ lock.unlock ();
111+ switch (task.cmd ) {
112+ case RPC_CMD_ALLOC_BUFFER:
113+ server.alloc_buffer (std::get<rpc_msg_alloc_buffer_req>(task.req ), std::get<rpc_msg_alloc_buffer_rsp>(task.rsp ));
114+ std::get<bool >(task.rsp ) = true ;
115+ break ;
116+ case RPC_CMD_GET_ALIGNMENT:
117+ server.get_alignment (std::get<rpc_msg_get_alignment_rsp>(task.rsp ));
118+ std::get<bool >(task.rsp ) = true ;
119+ break ;
120+ case RPC_CMD_GET_MAX_SIZE:
121+ server.get_max_size (std::get<rpc_msg_get_max_size_rsp>(task.rsp ));
122+ std::get<bool >(task.rsp ) = true ;
123+ break ;
124+ case RPC_CMD_BUFFER_GET_BASE:
125+ std::get<bool >(task.rsp ) = server.buffer_get_base (std::get<rpc_msg_buffer_get_base_req>(task.req ), std::get<rpc_msg_buffer_get_base_rsp>(task.rsp ));
126+ break ;
127+ case RPC_CMD_FREE_BUFFER:
128+ std::get<bool >(task.rsp ) = server.free_buffer (std::get<rpc_msg_free_buffer_req>(task.req ));
129+ break ;
130+ case RPC_CMD_BUFFER_CLEAR:
131+ std::get<bool >(task.rsp )= server.buffer_clear (std::get<rpc_msg_buffer_clear_req>(task.req ));
132+ break ;
133+ case RPC_CMD_SET_TENSOR:
134+ std::get<bool >(task.rsp ) = server.set_tensor (std::get<std::vector<uint8_t >>(task.req ));
135+ break ;
136+ case RPC_CMD_GET_TENSOR:
137+ std::get<bool >(task.rsp ) = server.get_tensor (std::get<rpc_msg_get_tensor_req>(task.req ), std::get<std::vector<uint8_t >>(task.rsp ));
138+ break ;
139+ case RPC_CMD_COPY_TENSOR:
140+ std::get<bool >(task.rsp ) = server.copy_tensor (std::get<rpc_msg_copy_tensor_req>(task.req ), std::get<rpc_msg_copy_tensor_rsp>(task.rsp ));
141+ break ;
142+ case RPC_CMD_GRAPH_COMPUTE:
143+ std::get<bool >(task.rsp ) = server.graph_compute (std::get<std::vector<uint8_t >>(task.req ), std::get<rpc_msg_graph_compute_rsp>(task.rsp ));
144+ break ;
145+ case RPC_CMD_GET_DEVICE_MEMORY:
146+ std::get<rpc_msg_get_device_memory_rsp>(task.rsp ).free_mem = ctx->free_mem ;
147+ std::get<rpc_msg_get_device_memory_rsp>(task.rsp ).total_mem = ctx->total_mem ;
148+ std::get<bool >(task.rsp ) = true ;
149+ break ;
150+ case RPC_CMD_INIT_TENSOR:
151+ std::get<bool >(task.rsp ) = server.init_tensor (std::get<rpc_msg_init_tensor_req>(task.req ));
152+ break ;
153+ case RPC_CMD_GET_ALLOC_SIZE:
154+ std::get<bool >(task.rsp ) = server.get_alloc_size (std::get<rpc_msg_get_alloc_size_req>(task.req ), std::get<rpc_msg_get_alloc_size_rsp>(task.rsp ));
155+ break ;
156+ default :
157+ std::get<bool >(task.rsp ) = false ;
158+ break ;
159+ }
160+
161+ std::lock_guard<std::mutex> response_lock (task.response_mutex );
162+ if (!send_response (task)) {
163+ std::get<bool >(task.rsp ) = true ;
164+ }
165+ }
166+ }
167+
168+ // if you change, then do synchronize change with such name function (ggml-rpc.cpp)
169+ static void rpc_serve_client (ggml_backend_t backend, sockfd_t sockfd, std::shared_ptr<rpc_queue_t <rpc_server_task_t >> _queue) {
170+ auto queue = _queue.get ();
171+ rpc_server server (backend);
172+ while (true ) {
173+ rpc_cmd cmd;
174+ if (!recv_data (sockfd, &cmd, 1 )) {
175+ break ;
176+ }
177+ if (cmd >= RPC_CMD_COUNT) {
178+ // fail fast if the command is invalid
179+ fprintf (stderr, " Unknown command: %d\n " , cmd);
180+ break ;
181+ }
182+ rpc_server_task_t ::req_t req;
183+ switch (cmd) {
184+ case RPC_CMD_ALLOC_BUFFER: {
185+ rpc_msg_alloc_buffer_req request;
186+ if (!recv_msg (sockfd, &request, sizeof (request))) {
187+ return ;
188+ }
189+ req = request;
190+ break ;
191+ }
192+ case RPC_CMD_GET_ALLOC_SIZE: {
193+ rpc_msg_get_alloc_size_req request;
194+ if (!recv_msg (sockfd, &request, sizeof (request))) {
195+ return ;
196+ }
197+ req = request;
198+ break ;
199+ }
200+ case RPC_CMD_GET_ALIGNMENT: {
201+ if (!recv_msg (sockfd, nullptr , 0 )) {
202+ return ;
203+ }
204+ break ;
205+ }
206+ case RPC_CMD_GET_MAX_SIZE: {
207+ if (!recv_msg (sockfd, nullptr , 0 )) {
208+ return ;
209+ }
210+ break ;
211+ }
212+ case RPC_CMD_BUFFER_GET_BASE: {
213+ rpc_msg_buffer_get_base_req request;
214+ if (!recv_msg (sockfd, &request, sizeof (request))) {
215+ return ;
216+ }
217+ req = request;
218+ break ;
219+ }
220+ case RPC_CMD_FREE_BUFFER: {
221+ rpc_msg_free_buffer_req request;
222+ if (!recv_msg (sockfd, &request, sizeof (request))) {
223+ return ;
224+ }
225+ req = request;
226+ break ;
227+ }
228+ case RPC_CMD_BUFFER_CLEAR: {
229+ rpc_msg_buffer_clear_req request;
230+ if (!recv_msg (sockfd, &request, sizeof (request))) {
231+ return ;
232+ }
233+ req = request;
234+ break ;
235+ }
236+ case RPC_CMD_SET_TENSOR: {
237+ std::vector<uint8_t > input;
238+ if (!recv_msg (sockfd, input)) {
239+ return ;
240+ }
241+ req = input;
242+ break ;
243+ }
244+ case RPC_CMD_INIT_TENSOR: {
245+ rpc_msg_init_tensor_req request;
246+ if (!recv_msg (sockfd, &request,sizeof (request))) {
247+ return ;
248+ }
249+ req = request;
250+ break ;
251+ }
252+ case RPC_CMD_GET_TENSOR: {
253+ rpc_msg_get_tensor_req request;
254+ if (!recv_msg (sockfd, &request, sizeof (request))) {
255+ return ;
256+ }
257+ req = request;
258+ break ;
259+ }
260+ case RPC_CMD_COPY_TENSOR: {
261+ rpc_msg_copy_tensor_req request;
262+ if (!recv_msg (sockfd, &request, sizeof (request))) {
263+ return ;
264+ }
265+ req = request;
266+ break ;
267+ }
268+ case RPC_CMD_GRAPH_COMPUTE: {
269+ std::vector<uint8_t > input;
270+ if (!recv_msg (sockfd, input)) {
271+ return ;
272+ }
273+ req = input;
274+ break ;
275+ }
276+ case RPC_CMD_GET_DEVICE_MEMORY: {
277+ if (!recv_msg (sockfd, nullptr , 0 )) {
278+ return ;
279+ }
280+ break ;
281+ }
282+ default : {
283+ fprintf (stderr, " Unknown command: %d\n " , cmd);
284+ return ;
285+ }
286+ }
287+ std::lock_guard<std::mutex> lock (queue->mutex );
288+ queue->tasks .push (rpc_server_task_t {cmd, req, sockfd});
289+ queue->cond .notify_one ();
290+ }
291+ }
292+ #endif
0 commit comments