2525# include < netdb.h>
2626# include < unistd.h>
2727#endif
28- #include < string.h >
28+ #include < cstring >
2929
3030#define UNUSED GGML_UNUSED
3131
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
630630 return (enum ggml_status)output[0 ];
631631}
632632
633- static bool ggml_backend_rpc_supports_op (ggml_backend_t backend, const ggml_tensor * op) {
634- UNUSED (backend);
635- UNUSED (op);
636- // TODO: call the remote backend and cache the results
637- return true ;
638- }
639-
640- static bool ggml_backend_rpc_supports_buft (ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641- if (!buft || buft->iface .get_name != ggml_backend_rpc_buffer_type_name) {
642- return false ;
643- }
644- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
645- ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
646- return buft_ctx->endpoint == rpc_ctx->endpoint ;
647- }
648-
649633static ggml_backend_i ggml_backend_rpc_interface = {
650634 /* .get_name = */ ggml_backend_rpc_name,
651635 /* .free = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659643 /* .graph_plan_update = */ NULL ,
660644 /* .graph_plan_compute = */ NULL ,
661645 /* .graph_compute = */ ggml_backend_rpc_graph_compute,
662- /* .supports_op = */ ggml_backend_rpc_supports_op ,
663- /* .supports_buft = */ ggml_backend_rpc_supports_buft ,
646+ /* .supports_op = */ NULL ,
647+ /* .supports_buft = */ NULL ,
664648 /* .offload_op = */ NULL ,
665649 /* .event_record = */ NULL ,
666650 /* .event_wait = */ NULL ,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691675
692676 ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
693677 /* .iface = */ ggml_backend_rpc_buffer_type_interface,
694- /* .device = */ nullptr ,
678+ /* .device = */ ggml_backend_rpc_add_device (endpoint) ,
695679 /* .context = */ buft_ctx
696680 };
697681 buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707691 ggml_backend_t backend = new ggml_backend {
708692 /* .guid = */ ggml_backend_rpc_guid (),
709693 /* .interface = */ ggml_backend_rpc_interface,
710- /* .device = */ nullptr ,
694+ /* .device = */ ggml_backend_rpc_add_device (endpoint) ,
711695 /* .context = */ ctx
712696 };
713697 return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11891173 }
11901174}
11911175
1192- void start_rpc_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1176+ void ggml_backend_rpc_start_rpc_server (ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
11931177 std::string host;
11941178 int port;
11951179 if (!parse_endpoint (endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12261210 WSACleanup ();
12271211#endif
12281212}
1213+
1214+ // device interface
1215+
1216+ struct ggml_backend_rpc_device_context {
1217+ std::string endpoint;
1218+ std::string name;
1219+ };
1220+
1221+ static const char * ggml_backend_rpc_device_get_name (ggml_backend_dev_t dev) {
1222+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1223+
1224+ return ctx->name .c_str ();
1225+ }
1226+
1227+ static const char * ggml_backend_rpc_device_get_description (ggml_backend_dev_t dev) {
1228+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1229+
1230+ return ctx->name .c_str ();
1231+ }
1232+
1233+ static void ggml_backend_rpc_device_get_memory (ggml_backend_dev_t dev, size_t * free, size_t * total) {
1234+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1235+
1236+ ggml_backend_rpc_get_device_memory (ctx->endpoint .c_str (), free, total);
1237+
1238+ UNUSED (dev);
1239+ }
1240+
1241+ static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type (ggml_backend_dev_t dev) {
1242+ // TODO: obtain value from the server
1243+ return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
1244+
1245+ UNUSED (dev);
1246+ }
1247+
1248+ static void ggml_backend_rpc_device_get_props (ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1249+ props->name = ggml_backend_rpc_device_get_name (dev);
1250+ props->description = ggml_backend_rpc_device_get_description (dev);
1251+ props->type = ggml_backend_rpc_device_get_type (dev);
1252+ ggml_backend_rpc_device_get_memory (dev, &props->memory_free , &props->memory_total );
1253+ props->caps = {
1254+ /* .async = */ false ,
1255+ /* .host_buffer = */ false ,
1256+ /* .buffer_from_host_ptr = */ false ,
1257+ /* .events = */ false ,
1258+ };
1259+ }
1260+
1261+ static ggml_backend_t ggml_backend_rpc_device_init (ggml_backend_dev_t dev, const char * params) {
1262+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1263+
1264+ return ggml_backend_rpc_init (ctx->endpoint .c_str ());
1265+
1266+ UNUSED (params);
1267+ }
1268+
1269+ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type (ggml_backend_dev_t dev) {
1270+ ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context ;
1271+
1272+ return ggml_backend_rpc_buffer_type (ctx->endpoint .c_str ());
1273+
1274+ UNUSED (dev);
1275+ }
1276+
1277+ static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr (ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1278+ return ggml_backend_cpu_buffer_from_ptr (ptr, size);
1279+
1280+ UNUSED (dev);
1281+ UNUSED (max_tensor_size);
1282+ }
1283+
1284+ static bool ggml_backend_rpc_device_supports_op (ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1285+ UNUSED (dev);
1286+ UNUSED (op);
1287+ // TODO: call the remote backend and cache the results
1288+ return true ;
1289+ }
1290+
1291+ static bool ggml_backend_rpc_device_supports_buft (ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1292+ if (!buft || buft->iface .get_name != ggml_backend_rpc_buffer_type_name) {
1293+ return false ;
1294+ }
1295+ ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
1296+ ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context ;
1297+ return buft_ctx->endpoint == dev_ctx->endpoint ;
1298+ }
1299+
1300+ static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1301+ /* .get_name = */ ggml_backend_rpc_device_get_name,
1302+ /* .get_description = */ ggml_backend_rpc_device_get_description,
1303+ /* .get_memory = */ ggml_backend_rpc_device_get_memory,
1304+ /* .get_type = */ ggml_backend_rpc_device_get_type,
1305+ /* .get_props = */ ggml_backend_rpc_device_get_props,
1306+ /* .init_backend = */ ggml_backend_rpc_device_init,
1307+ /* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1308+ /* .get_host_buffer_type = */ NULL ,
1309+ /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
1310+ /* .supports_op = */ ggml_backend_rpc_device_supports_op,
1311+ /* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1312+ /* .offload_op = */ NULL ,
1313+ /* .event_new = */ NULL ,
1314+ /* .event_free = */ NULL ,
1315+ /* .event_synchronize = */ NULL ,
1316+ };
1317+
1318+ // backend reg interface
1319+
1320+ static const char * ggml_backend_rpc_reg_get_name (ggml_backend_reg_t reg) {
1321+ return " RPC" ;
1322+
1323+ UNUSED (reg);
1324+ }
1325+
1326+ static size_t ggml_backend_rpc_reg_get_device_count (ggml_backend_reg_t reg) {
1327+ return 0 ;
1328+
1329+ UNUSED (reg);
1330+ }
1331+
1332+ static ggml_backend_dev_t ggml_backend_rpc_reg_get_device (ggml_backend_reg_t reg, size_t index) {
1333+ GGML_ABORT (" The RPC backend does not have enumerated devices - use ggml_backend_add_device instead" );
1334+
1335+ UNUSED (reg);
1336+ UNUSED (index);
1337+ }
1338+
1339+ static void * ggml_backend_rpc_get_proc_address (ggml_backend_reg_t reg, const char * name) {
1340+ if (std::strcmp (name, " ggml_backend_rpc_add_device" ) == 0 ) {
1341+ return (void *)ggml_backend_rpc_add_device;
1342+ }
1343+ return NULL ;
1344+
1345+ UNUSED (reg);
1346+ }
1347+
1348+ static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1349+ /* .get_name = */ ggml_backend_rpc_reg_get_name,
1350+ /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1351+ /* .get_device = */ ggml_backend_rpc_reg_get_device,
1352+ /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1353+ };
1354+
1355+ ggml_backend_reg_t ggml_backend_rpc_reg (void ) {
1356+ static struct ggml_backend_reg ggml_backend_rpc_reg = {
1357+ /* .iface = */ ggml_backend_rpc_reg_i,
1358+ /* .context = */ NULL ,
1359+ };
1360+
1361+ return &ggml_backend_rpc_reg;
1362+ }
1363+
1364+ ggml_backend_dev_t ggml_backend_rpc_add_device (const char * endpoint) {
1365+ static std::unordered_map<std::string, ggml_backend_dev_t > dev_map;
1366+
1367+ static std::mutex mutex;
1368+ std::lock_guard<std::mutex> lock (mutex);
1369+
1370+ if (dev_map.find (endpoint) != dev_map.end ()) {
1371+ return dev_map[endpoint];
1372+ }
1373+
1374+ ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1375+ /* .endpoint = */ endpoint,
1376+ /* .name = */ " RPC[" + std::string (endpoint) + " ]" ,
1377+ };
1378+
1379+ ggml_backend_dev_t dev = new ggml_backend_device {
1380+ /* .iface = */ ggml_backend_rpc_device_i,
1381+ /* .reg = */ ggml_backend_rpc_reg (),
1382+ /* .context = */ ctx,
1383+ };
1384+
1385+ dev_map[endpoint] = dev;
1386+
1387+ return dev;
1388+ }
0 commit comments