66#include < string>
77#include < vector>
88#include < memory>
9+ #include < mutex>
910#include < unordered_map>
1011#include < unordered_set>
1112#ifdef _WIN32
@@ -47,6 +48,7 @@ struct socket_t {
4748 sockfd_t fd;
4849 socket_t (sockfd_t fd) : fd(fd) {}
4950 ~socket_t () {
51+ GGML_PRINT_DEBUG (" [%s] closing socket %d\n " , __func__, this ->fd );
5052#ifdef _WIN32
5153 closesocket (this ->fd );
5254#else
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
9799}
98100
99101struct ggml_backend_rpc_buffer_type_context {
100- std::shared_ptr< socket_t > sock ;
102+ std::string endpoint ;
101103 std::string name;
102104 size_t alignment;
103105 size_t max_size;
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
// Per-instance state of an RPC backend.
// Members are aggregate-initialized in declaration order, so keep the order stable.
struct ggml_backend_rpc_context {
    // address of the RPC server; handed to get_socket(), which parses it as "host:port"
    std::string endpoint;
    // name of this backend instance
    std::string name;
};
112112
113113struct ggml_backend_rpc_buffer_context {
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
231231 return true ;
232232}
233233
// Splits an endpoint of the form "host:port" at the first ':'.
//
// endpoint: text to parse
// host:     receives everything before the first ':' (may be empty)
// port:     receives the decimal port number, 0..65535
//
// Returns true on success. Returns false, leaving host and port untouched,
// when there is no ':' or the port part is empty, contains a non-digit
// character or exceeds 65535. Never throws on malformed port text, so callers
// that only check the boolean result are safe.
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    const size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }

    const std::string port_str = endpoint.substr(pos + 1);
    // at most 5 digits: keeps the accumulation below far away from int overflow
    if (port_str.empty() || port_str.size() > 5) {
        return false;
    }

    int value = 0;
    for (const char c : port_str) {
        if (c < '0' || c > '9') {
            return false;
        }
        value = value * 10 + (c - '0');
    }
    if (value > 65535) {
        return false;
    }

    // commit the outputs only once the whole endpoint has been validated
    host = endpoint.substr(0, pos);
    port = value;
    return true;
}
244243
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
273272
274273// RPC client-side implementation
275274
275+ static std::shared_ptr<socket_t > get_socket (const std::string & endpoint) {
276+ static std::mutex mutex;
277+ std::lock_guard<std::mutex> lock (mutex);
278+ static std::unordered_map<std::string, std::weak_ptr<socket_t >> sockets;
279+ static bool initialized = false ;
280+
281+ auto it = sockets.find (endpoint);
282+ if (it != sockets.end ()) {
283+ if (auto sock = it->second .lock ()) {
284+ return sock;
285+ }
286+ }
287+ std::string host;
288+ int port;
289+ if (!parse_endpoint (endpoint, host, port)) {
290+ return nullptr ;
291+ }
292+ #ifdef _WIN32
293+ if (!initialized) {
294+ WSADATA wsaData;
295+ int res = WSAStartup (MAKEWORD (2 , 2 ), &wsaData);
296+ if (res != 0 ) {
297+ return nullptr ;
298+ }
299+ initialized = true ;
300+ }
301+ #else
302+ UNUSED (initialized);
303+ #endif
304+ auto sock = socket_connect (host.c_str (), port);
305+ if (sock == nullptr ) {
306+ return nullptr ;
307+ }
308+ GGML_PRINT_DEBUG (" [%s] connected to %s, sockfd=%d\n " , __func__, endpoint.c_str (), sock->fd );
309+ sockets[endpoint] = sock;
310+ return sock;
311+ }
312+
276313GGML_CALL static const char * ggml_backend_rpc_buffer_get_name (ggml_backend_buffer_t buffer) {
277314 ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
278315 return ctx->name .c_str ();
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
442479 std::vector<uint8_t > input (input_size, 0 );
443480 memcpy (input.data (), &size, sizeof (size));
444481 std::vector<uint8_t > output;
445- bool status = send_rpc_cmd (buft_ctx->sock , ALLOC_BUFFER, input, output);
482+ auto sock = get_socket (buft_ctx->endpoint );
483+ bool status = send_rpc_cmd (sock, ALLOC_BUFFER, input, output);
446484 GGML_ASSERT (status);
447485 GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
448486 // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
453491 if (remote_ptr != 0 ) {
454492 ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
455493 ggml_backend_rpc_buffer_interface,
456- new ggml_backend_rpc_buffer_context{buft_ctx-> sock , {}, remote_ptr, " RPC" },
494+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC" },
457495 remote_size);
458496 return buffer;
459497 } else {
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
508546 }
509547 ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
510548 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
511- return buft_ctx->sock == rpc_ctx->sock ;
549+ return buft_ctx->endpoint == rpc_ctx->endpoint ;
512550}
513551
514552static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
521559 /* .is_host = */ NULL ,
522560};
523561
524-
525562GGML_CALL static const char * ggml_backend_rpc_name (ggml_backend_t backend) {
526563 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
527564
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
530567
531568GGML_CALL static void ggml_backend_rpc_free (ggml_backend_t backend) {
532569 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
533- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft ->context ;
534- delete buft_ctx;
535- delete rpc_ctx->buft ;
536570 delete rpc_ctx;
537571 delete backend;
538572}
539573
540574GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type (ggml_backend_t backend) {
541575 ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
542- return ctx->buft ;
576+ return ggml_backend_rpc_buffer_type ( ctx->endpoint . c_str ()) ;
543577}
544578
545579GGML_CALL static void ggml_backend_rpc_synchronize (ggml_backend_t backend) {
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
590624 std::vector<uint8_t > input;
591625 serialize_graph (cgraph, input);
592626 std::vector<uint8_t > output;
593- bool status = send_rpc_cmd (rpc_ctx->sock , GRAPH_COMPUTE, input, output);
627+ auto sock = get_socket (rpc_ctx->endpoint );
628+ bool status = send_rpc_cmd (sock, GRAPH_COMPUTE, input, output);
594629 GGML_ASSERT (status);
595630 GGML_ASSERT (output.size () == 1 );
596631 return (enum ggml_status)output[0 ];
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
624659 /* .event_synchronize = */ NULL ,
625660};
626661
627- static std::unordered_map<std::string, ggml_backend_t > instances;
628-
629662GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type (const char * endpoint) {
630- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
631- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type (backend) : nullptr ;
632- }
633-
634- GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
635- std::string endpoint_str (endpoint);
636- if (instances.find (endpoint_str) != instances.end ()) {
637- return instances[endpoint_str];
638- }
639- #ifdef _WIN32
640- {
641- WSADATA wsaData;
642- int res = WSAStartup (MAKEWORD (2 , 2 ), &wsaData);
643- if (res != 0 ) {
644- return nullptr ;
645- }
646- }
647- #endif
648- fprintf (stderr, " Connecting to %s\n " , endpoint);
649- std::string host;
650- int port;
651- if (!parse_endpoint (endpoint, host, port)) {
652- return nullptr ;
653- }
654- auto sock = socket_connect (host.c_str (), port);
663+ static std::mutex mutex;
664+ std::lock_guard<std::mutex> lock (mutex);
665+ // NOTE: buffer types are allocated and never freed; this is by design
666+ static std::unordered_map<std::string, ggml_backend_buffer_type_t > buft_map;
667+ auto it = buft_map.find (endpoint);
668+ if (it != buft_map.end ()) {
669+ return it->second ;
670+ }
671+ auto sock = get_socket (endpoint);
655672 if (sock == nullptr ) {
656673 return nullptr ;
657674 }
658675 size_t alignment = get_alignment (sock);
659676 size_t max_size = get_max_size (sock);
660677 ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
661- /* .sock = */ sock ,
662- /* .name = */ " RPC" + std::to_string (sock-> fd ) ,
678+ /* .endpoint = */ endpoint ,
679+ /* .name = */ " RPC[ " + std::string (endpoint) + " ] " ,
663680 /* .alignment = */ alignment,
664- /* .max_size = */ max_size
681+ /* .max_size = */ max_size
665682 };
666683
667684 ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
668685 /* .iface = */ ggml_backend_rpc_buffer_type_interface,
669686 /* .context = */ buft_ctx
670687 };
688+ buft_map[endpoint] = buft;
689+ return buft;
690+ }
671691
692+ GGML_CALL ggml_backend_t ggml_backend_rpc_init (const char * endpoint) {
672693 ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
673- /* .endpoint = */ endpoint,
674- /* .name = */ " RPC" + std::to_string (sock->fd ),
675- /* .sock = */ sock,
676- /* .buft = */ buft
694+ /* .endpoint = */ endpoint,
695+ /* .name = */ " RPC" ,
677696 };
678697
679- instances[endpoint] = new ggml_backend {
698+ ggml_backend_t backend = new ggml_backend {
680699 /* .guid = */ ggml_backend_rpc_guid (),
681700 /* .interface = */ ggml_backend_rpc_interface,
682701 /* .context = */ ctx
683702 };
684-
685- return instances[endpoint];
703+ return backend;
686704}
687705
688706GGML_API GGML_CALL bool ggml_backend_is_rpc (ggml_backend_t backend) {
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
706724}
707725
708726GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
709- ggml_backend_t backend = ggml_backend_rpc_init (endpoint);
710- if (backend == nullptr ) {
727+ auto sock = get_socket (endpoint);
728+ if (sock == nullptr ) {
711729 *free = 0 ;
712730 *total = 0 ;
713731 return ;
714732 }
715- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context ;
716- get_device_memory (ctx->sock , free, total);
733+ get_device_memory (sock, free, total);
717734}
718735
719736// RPC server-side implementation
0 commit comments