Skip to content

Commit bc40adb

Browse files
committed
rpc : add backend registry / device interfaces
1 parent c81f3bb commit bc40adb

File tree

5 files changed

+229
-73
lines changed

5 files changed

+229
-73
lines changed

examples/rpc/rpc-server.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ int main(int argc, char * argv[]) {
151151
get_backend_memory(&free_mem, &total_mem);
152152
}
153153
printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
154-
start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
154+
ggml_backend_rpc_start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
155155
ggml_backend_free(backend);
156156
return 0;
157157
}

ggml/include/ggml-rpc.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
1717

1818
GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
1919

20-
GGML_API void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
20+
GGML_API void ggml_backend_rpc_start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
21+
22+
GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
23+
24+
GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
2125

2226
#ifdef __cplusplus
2327
}

ggml/src/ggml-backend.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
542542
#include "ggml-blas.h"
543543
#endif
544544

545+
#ifdef GGML_USE_RPC
546+
#include "ggml-rpc.h"
547+
#endif
548+
545549
struct ggml_backend_registry {
546550
std::vector<ggml_backend_reg_t> backends;
547551
std::vector<ggml_backend_dev_t> devices;
@@ -556,6 +560,9 @@ struct ggml_backend_registry {
556560
#ifdef GGML_USE_BLAS
557561
register_backend(ggml_backend_blas_reg());
558562
#endif
563+
#ifdef GGML_USE_RPC
564+
register_backend(ggml_backend_rpc_reg());
565+
#endif
559566

560567
// TODO: sycl, vulkan, kompute, cann
561568

ggml/src/ggml-rpc.cpp

Lines changed: 182 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# include <netdb.h>
2626
# include <unistd.h>
2727
#endif
28-
#include <string.h>
28+
#include <cstring>
2929

3030
#define UNUSED GGML_UNUSED
3131

@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
630630
return (enum ggml_status)output[0];
631631
}
632632

633-
static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
634-
UNUSED(backend);
635-
UNUSED(op);
636-
//TODO: call the remote backend and cache the results
637-
return true;
638-
}
639-
640-
static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641-
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
642-
return false;
643-
}
644-
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
645-
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
646-
return buft_ctx->endpoint == rpc_ctx->endpoint;
647-
}
648-
649633
static ggml_backend_i ggml_backend_rpc_interface = {
650634
/* .get_name = */ ggml_backend_rpc_name,
651635
/* .free = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
659643
/* .graph_plan_update = */ NULL,
660644
/* .graph_plan_compute = */ NULL,
661645
/* .graph_compute = */ ggml_backend_rpc_graph_compute,
662-
/* .supports_op = */ ggml_backend_rpc_supports_op,
663-
/* .supports_buft = */ ggml_backend_rpc_supports_buft,
646+
/* .supports_op = */ NULL,
647+
/* .supports_buft = */ NULL,
664648
/* .offload_op = */ NULL,
665649
/* .event_record = */ NULL,
666650
/* .event_wait = */ NULL,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
691675

692676
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
693677
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
694-
/* .device = */ nullptr,
678+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
695679
/* .context = */ buft_ctx
696680
};
697681
buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
707691
ggml_backend_t backend = new ggml_backend {
708692
/* .guid = */ ggml_backend_rpc_guid(),
709693
/* .interface = */ ggml_backend_rpc_interface,
710-
/* .device = */ nullptr,
694+
/* .device = */ ggml_backend_rpc_add_device(endpoint),
711695
/* .context = */ ctx
712696
};
713697
return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
11891173
}
11901174
}
11911175

1192-
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
1176+
void ggml_backend_rpc_start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
11931177
std::string host;
11941178
int port;
11951179
if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12261210
WSACleanup();
12271211
#endif
12281212
}
1213+
1214+
// device interface
1215+
1216+
struct ggml_backend_rpc_device_context {
1217+
std::string endpoint;
1218+
std::string name;
1219+
};
1220+
1221+
static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
1222+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1223+
1224+
return ctx->name.c_str();
1225+
}
1226+
1227+
static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
1228+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1229+
1230+
return ctx->name.c_str();
1231+
}
1232+
1233+
static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1234+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1235+
1236+
ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
1237+
1238+
UNUSED(dev);
1239+
}
1240+
1241+
static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
1242+
// TODO: obtain value from the server
1243+
return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
1244+
1245+
UNUSED(dev);
1246+
}
1247+
1248+
static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
1249+
props->name = ggml_backend_rpc_device_get_name(dev);
1250+
props->description = ggml_backend_rpc_device_get_description(dev);
1251+
props->type = ggml_backend_rpc_device_get_type(dev);
1252+
ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
1253+
props->caps = {
1254+
/* .async = */ false,
1255+
/* .host_buffer = */ false,
1256+
/* .buffer_from_host_ptr = */ false,
1257+
/* .events = */ false,
1258+
};
1259+
}
1260+
1261+
static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
1262+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1263+
1264+
return ggml_backend_rpc_init(ctx->endpoint.c_str());
1265+
1266+
UNUSED(params);
1267+
}
1268+
1269+
static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
1270+
ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
1271+
1272+
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
1273+
1274+
UNUSED(dev);
1275+
}
1276+
1277+
static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1278+
return ggml_backend_cpu_buffer_from_ptr(ptr, size);
1279+
1280+
UNUSED(dev);
1281+
UNUSED(max_tensor_size);
1282+
}
1283+
1284+
static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
1285+
UNUSED(dev);
1286+
UNUSED(op);
1287+
//TODO: call the remote backend and cache the results
1288+
return true;
1289+
}
1290+
1291+
static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1292+
if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
1293+
return false;
1294+
}
1295+
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
1296+
ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
1297+
return buft_ctx->endpoint == dev_ctx->endpoint;
1298+
}
1299+
1300+
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
1301+
/* .get_name = */ ggml_backend_rpc_device_get_name,
1302+
/* .get_description = */ ggml_backend_rpc_device_get_description,
1303+
/* .get_memory = */ ggml_backend_rpc_device_get_memory,
1304+
/* .get_type = */ ggml_backend_rpc_device_get_type,
1305+
/* .get_props = */ ggml_backend_rpc_device_get_props,
1306+
/* .init_backend = */ ggml_backend_rpc_device_init,
1307+
/* .get_buffer_type = */ ggml_backend_rpc_device_get_buffer_type,
1308+
/* .get_host_buffer_type = */ NULL,
1309+
/* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
1310+
/* .supports_op = */ ggml_backend_rpc_device_supports_op,
1311+
/* .supports_buft = */ ggml_backend_rpc_device_supports_buft,
1312+
/* .offload_op = */ NULL,
1313+
/* .event_new = */ NULL,
1314+
/* .event_free = */ NULL,
1315+
/* .event_synchronize = */ NULL,
1316+
};
1317+
1318+
// backend reg interface
1319+
1320+
static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
1321+
return "RPC";
1322+
1323+
UNUSED(reg);
1324+
}
1325+
1326+
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1327+
return 0;
1328+
1329+
UNUSED(reg);
1330+
}
1331+
1332+
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1333+
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1334+
1335+
UNUSED(reg);
1336+
UNUSED(index);
1337+
}
1338+
1339+
static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
1340+
if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
1341+
return (void *)ggml_backend_rpc_add_device;
1342+
}
1343+
return NULL;
1344+
1345+
UNUSED(reg);
1346+
}
1347+
1348+
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
1349+
/* .get_name = */ ggml_backend_rpc_reg_get_name,
1350+
/* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
1351+
/* .get_device = */ ggml_backend_rpc_reg_get_device,
1352+
/* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
1353+
};
1354+
1355+
ggml_backend_reg_t ggml_backend_rpc_reg(void) {
1356+
static struct ggml_backend_reg ggml_backend_rpc_reg = {
1357+
/* .iface = */ ggml_backend_rpc_reg_i,
1358+
/* .context = */ NULL,
1359+
};
1360+
1361+
return &ggml_backend_rpc_reg;
1362+
}
1363+
1364+
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1365+
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
1366+
1367+
static std::mutex mutex;
1368+
std::lock_guard<std::mutex> lock(mutex);
1369+
1370+
if (dev_map.find(endpoint) != dev_map.end()) {
1371+
return dev_map[endpoint];
1372+
}
1373+
1374+
ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
1375+
/* .endpoint = */ endpoint,
1376+
/* .name = */ "RPC[" + std::string(endpoint) + "]",
1377+
};
1378+
1379+
ggml_backend_dev_t dev = new ggml_backend_device {
1380+
/* .iface = */ ggml_backend_rpc_device_i,
1381+
/* .reg = */ ggml_backend_rpc_reg(),
1382+
/* .context = */ ctx,
1383+
};
1384+
1385+
dev_map[endpoint] = dev;
1386+
1387+
return dev;
1388+
}

0 commit comments

Comments
 (0)