2525#  include  < netdb.h> 
2626#  include  < unistd.h> 
2727#endif 
28- #include  < string.h > 
28+ #include  < cstring > 
2929
3030#define  UNUSED  GGML_UNUSED
3131
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
630630    return  (enum  ggml_status)output[0 ];
631631}
632632
633- static  bool  ggml_backend_rpc_supports_op (ggml_backend_t  backend, const  ggml_tensor * op) {
634-     UNUSED (backend);
635-     UNUSED (op);
636-     // TODO: call the remote backend and cache the results
637-     return  true ;
638- }
639- 
640- static  bool  ggml_backend_rpc_supports_buft (ggml_backend_t  backend, ggml_backend_buffer_type_t  buft) {
641-     if  (!buft || buft->iface .get_name  != ggml_backend_rpc_buffer_type_name) {
642-         return  false ;
643-     }
644-     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
645-     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
646-     return  buft_ctx->endpoint  == rpc_ctx->endpoint ;
647- }
648- 
649633static  ggml_backend_i ggml_backend_rpc_interface = {
650634    /*  .get_name                = */   ggml_backend_rpc_name,
651635    /*  .free                    = */   ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659643    /*  .graph_plan_update       = */   NULL ,
660644    /*  .graph_plan_compute      = */   NULL ,
661645    /*  .graph_compute           = */   ggml_backend_rpc_graph_compute,
662-     /*  .supports_op             = */   ggml_backend_rpc_supports_op ,
663-     /*  .supports_buft           = */   ggml_backend_rpc_supports_buft ,
646+     /*  .supports_op             = */   NULL ,
647+     /*  .supports_buft           = */   NULL ,
664648    /*  .offload_op              = */   NULL ,
665649    /*  .event_record            = */   NULL ,
666650    /*  .event_wait              = */   NULL ,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691675
692676    ggml_backend_buffer_type_t  buft = new  ggml_backend_buffer_type {
693677        /*  .iface   = */   ggml_backend_rpc_buffer_type_interface,
694-         /*  .device  = */   nullptr ,
678+         /*  .device  = */   ggml_backend_rpc_add_device (endpoint) ,
695679        /*  .context = */   buft_ctx
696680    };
697681    buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707691    ggml_backend_t  backend = new  ggml_backend {
708692        /*  .guid      = */   ggml_backend_rpc_guid (),
709693        /*  .interface = */   ggml_backend_rpc_interface,
710-         /*  .device    = */   nullptr ,
694+         /*  .device    = */   ggml_backend_rpc_add_device (endpoint) ,
711695        /*  .context   = */   ctx
712696    };
713697    return  backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11891173    }
11901174}
11911175
1192- void  start_rpc_server (ggml_backend_t  backend, const  char  * endpoint, size_t  free_mem, size_t  total_mem) {
1176+ void  ggml_backend_rpc_start_server (ggml_backend_t  backend, const  char  * endpoint, size_t  free_mem, size_t  total_mem) {
11931177    std::string host;
11941178    int  port;
11951179    if  (!parse_endpoint (endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12261210    WSACleanup ();
12271211#endif 
12281212}
1213+ 
1214+ //  device interface
1215+ 
1216+ struct  ggml_backend_rpc_device_context  {
1217+     std::string endpoint;
1218+     std::string name;
1219+ };
1220+ 
1221+ static  const  char  * ggml_backend_rpc_device_get_name (ggml_backend_dev_t  dev) {
1222+     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1223+ 
1224+     return  ctx->name .c_str ();
1225+ }
1226+ 
1227+ static  const  char  * ggml_backend_rpc_device_get_description (ggml_backend_dev_t  dev) {
1228+     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1229+ 
1230+     return  ctx->name .c_str ();
1231+ }
1232+ 
1233+ static  void  ggml_backend_rpc_device_get_memory (ggml_backend_dev_t  dev, size_t  * free, size_t  * total) {
1234+     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1235+ 
1236+     ggml_backend_rpc_get_device_memory (ctx->endpoint .c_str (), free, total);
1237+ 
1238+     UNUSED (dev);
1239+ }
1240+ 
1241+ static  enum  ggml_backend_dev_type ggml_backend_rpc_device_get_type (ggml_backend_dev_t  dev) {
1242+     //  TODO: obtain value from the server
1243+     return  GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
1244+ 
1245+     UNUSED (dev);
1246+ }
1247+ 
1248+ static  void  ggml_backend_rpc_device_get_props (ggml_backend_dev_t  dev, struct  ggml_backend_dev_props  * props) {
1249+     props->name         = ggml_backend_rpc_device_get_name (dev);
1250+     props->description  = ggml_backend_rpc_device_get_description (dev);
1251+     props->type         = ggml_backend_rpc_device_get_type (dev);
1252+     ggml_backend_rpc_device_get_memory (dev, &props->memory_free , &props->memory_total );
1253+     props->caps  = {
1254+         /*  .async                 = */   false ,
1255+         /*  .host_buffer           = */   false ,
1256+         /*  .buffer_from_host_ptr  = */   false ,
1257+         /*  .events                = */   false ,
1258+     };
1259+ }
1260+ 
1261+ static  ggml_backend_t  ggml_backend_rpc_device_init (ggml_backend_dev_t  dev, const  char  * params) {
1262+     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1263+ 
1264+     return  ggml_backend_rpc_init (ctx->endpoint .c_str ());
1265+ 
1266+     UNUSED (params);
1267+ }
1268+ 
1269+ static  ggml_backend_buffer_type_t  ggml_backend_rpc_device_get_buffer_type (ggml_backend_dev_t  dev) {
1270+     ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1271+ 
1272+     return  ggml_backend_rpc_buffer_type (ctx->endpoint .c_str ());
1273+ 
1274+     UNUSED (dev);
1275+ }
1276+ 
1277+ static  ggml_backend_buffer_t  ggml_backend_rpc_device_buffer_from_ptr (ggml_backend_dev_t  dev, void  * ptr, size_t  size, size_t  max_tensor_size) {
1278+     return  ggml_backend_cpu_buffer_from_ptr (ptr, size);
1279+ 
1280+     UNUSED (dev);
1281+     UNUSED (max_tensor_size);
1282+ }
1283+ 
1284+ static  bool  ggml_backend_rpc_device_supports_op (ggml_backend_dev_t  dev, const  struct  ggml_tensor  * op) {
1285+     UNUSED (dev);
1286+     UNUSED (op);
1287+     // TODO: call the remote backend and cache the results
1288+     return  true ;
1289+ }
1290+ 
1291+ static  bool  ggml_backend_rpc_device_supports_buft (ggml_backend_dev_t  dev, ggml_backend_buffer_type_t  buft) {
1292+     if  (!buft || buft->iface .get_name  != ggml_backend_rpc_buffer_type_name) {
1293+         return  false ;
1294+     }
1295+     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
1296+     ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context ;
1297+     return  buft_ctx->endpoint  == dev_ctx->endpoint ;
1298+ }
1299+ 
1300+ static  const  struct  ggml_backend_device_i  ggml_backend_rpc_device_i = {
1301+     /*  .get_name             = */   ggml_backend_rpc_device_get_name,
1302+     /*  .get_description      = */   ggml_backend_rpc_device_get_description,
1303+     /*  .get_memory           = */   ggml_backend_rpc_device_get_memory,
1304+     /*  .get_type             = */   ggml_backend_rpc_device_get_type,
1305+     /*  .get_props            = */   ggml_backend_rpc_device_get_props,
1306+     /*  .init_backend         = */   ggml_backend_rpc_device_init,
1307+     /*  .get_buffer_type      = */   ggml_backend_rpc_device_get_buffer_type,
1308+     /*  .get_host_buffer_type = */   NULL ,
1309+     /*  .buffer_from_host_ptr = */   ggml_backend_rpc_device_buffer_from_ptr,
1310+     /*  .supports_op          = */   ggml_backend_rpc_device_supports_op,
1311+     /*  .supports_buft        = */   ggml_backend_rpc_device_supports_buft,
1312+     /*  .offload_op           = */   NULL ,
1313+     /*  .event_new            = */   NULL ,
1314+     /*  .event_free           = */   NULL ,
1315+     /*  .event_synchronize    = */   NULL ,
1316+ };
1317+ 
1318+ //  backend reg interface
1319+ 
1320+ static  const  char  * ggml_backend_rpc_reg_get_name (ggml_backend_reg_t  reg) {
1321+     return  " RPC"  ;
1322+ 
1323+     UNUSED (reg);
1324+ }
1325+ 
1326+ static  size_t  ggml_backend_rpc_reg_get_device_count (ggml_backend_reg_t  reg) {
1327+     return  0 ;
1328+ 
1329+     UNUSED (reg);
1330+ }
1331+ 
1332+ static  ggml_backend_dev_t  ggml_backend_rpc_reg_get_device (ggml_backend_reg_t  reg, size_t  index) {
1333+     GGML_ABORT (" The RPC backend does not have enumerated devices - use ggml_backend_add_device instead"  );
1334+ 
1335+     UNUSED (reg);
1336+     UNUSED (index);
1337+ }
1338+ 
1339+ static  void  * ggml_backend_rpc_get_proc_address (ggml_backend_reg_t  reg, const  char  * name) {
1340+     if  (std::strcmp (name, " ggml_backend_rpc_add_device"  ) == 0 ) {
1341+         return  (void  *)ggml_backend_rpc_add_device;
1342+     }
1343+     return  NULL ;
1344+ 
1345+     UNUSED (reg);
1346+ }
1347+ 
1348+ static  const  struct  ggml_backend_reg_i  ggml_backend_rpc_reg_i = {
1349+     /*  .get_name         = */   ggml_backend_rpc_reg_get_name,
1350+     /*  .get_device_count = */   ggml_backend_rpc_reg_get_device_count,
1351+     /*  .get_device       = */   ggml_backend_rpc_reg_get_device,
1352+     /*  .get_proc_address = */   ggml_backend_rpc_get_proc_address,
1353+ };
1354+ 
1355+ ggml_backend_reg_t  ggml_backend_rpc_reg (void ) {
1356+     static  struct  ggml_backend_reg  ggml_backend_rpc_reg = {
1357+         /*  .iface   = */   ggml_backend_rpc_reg_i,
1358+         /*  .context = */   NULL ,
1359+     };
1360+ 
1361+     return  &ggml_backend_rpc_reg;
1362+ }
1363+ 
1364+ ggml_backend_dev_t  ggml_backend_rpc_add_device (const  char  * endpoint) {
1365+     static  std::unordered_map<std::string, ggml_backend_dev_t > dev_map;
1366+ 
1367+     static  std::mutex mutex;
1368+     std::lock_guard<std::mutex> lock (mutex);
1369+ 
1370+     if  (dev_map.find (endpoint) != dev_map.end ()) {
1371+         return  dev_map[endpoint];
1372+     }
1373+ 
1374+     ggml_backend_rpc_device_context * ctx = new  ggml_backend_rpc_device_context {
1375+         /*  .endpoint = */   endpoint,
1376+         /*  .name     = */   " RPC["   + std::string (endpoint) + " ]"  ,
1377+     };
1378+ 
1379+     ggml_backend_dev_t  dev = new  ggml_backend_device {
1380+         /*  .iface   = */   ggml_backend_rpc_device_i,
1381+         /*  .reg     = */   ggml_backend_rpc_reg (),
1382+         /*  .context = */   ctx,
1383+     };
1384+ 
1385+     dev_map[endpoint] = dev;
1386+ 
1387+     return  dev;
1388+ }
0 commit comments