Skip to content

Commit 34cc50d

Browse files
authored
Merge branch 'ggml-org:master' into master
2 parents de57584 + 66b0dbc commit 34cc50d

File tree

9 files changed

+450
-90
lines changed

9 files changed

+450
-90
lines changed

docs/ops.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ Legend:
100100
| SOFT_MAX_BACK ||| 🟡 | 🟡 ||| 🟡 |||
101101
| SQR ||||| 🟡 ||| 🟡 ||
102102
| SQRT ||||| 🟡 |||||
103-
| SSM_CONV |||||||| ||
104-
| SSM_SCAN |||||||| ||
103+
| SSM_CONV |||||||| ||
104+
| SSM_SCAN |||||||| ||
105105
| STEP |||| 🟡 | 🟡 || 🟡 |||
106106
| SUB ||||| 🟡 | 🟡 ||||
107107
| SUM ||||||||||

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
2121
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
2222

2323
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
24-
size_t n_threads, size_t n_devices,
25-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
24+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
2625

2726
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
2827
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ class rpc_server {
939939
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
940940
bool init_tensor(const rpc_msg_init_tensor_req & request);
941941
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942+
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
942943

943944
private:
944945
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -1458,14 +1459,28 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14581459
return true;
14591460
}
14601461

1462+
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1463+
uint32_t dev_id = request.device;
1464+
if (dev_id >= backends.size()) {
1465+
return false;
1466+
}
1467+
size_t free, total;
1468+
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1469+
ggml_backend_dev_memory(dev, &free, &total);
1470+
response.free_mem = free;
1471+
response.total_mem = total;
1472+
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1473+
return true;
1474+
}
1475+
14611476
rpc_server::~rpc_server() {
14621477
for (auto buffer : buffers) {
14631478
ggml_backend_buffer_free(buffer);
14641479
}
14651480
}
14661481

14671482
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1468-
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
1483+
sockfd_t sockfd) {
14691484
rpc_server server(backends, cache_dir);
14701485
uint8_t cmd;
14711486
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16891704
if (!recv_msg(sockfd, &request, sizeof(request))) {
16901705
return;
16911706
}
1692-
auto dev_id = request.device;
1693-
if (dev_id >= backends.size()) {
1707+
rpc_msg_get_device_memory_rsp response;
1708+
if (!server.get_device_memory(request, response)) {
16941709
return;
16951710
}
1696-
rpc_msg_get_device_memory_rsp response;
1697-
response.free_mem = free_mem[dev_id];
1698-
response.total_mem = total_mem[dev_id];
1699-
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
1700-
response.free_mem, response.total_mem);
17011711
if (!send_msg(sockfd, &response, sizeof(response))) {
17021712
return;
17031713
}
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
17121722
}
17131723

17141724
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1715-
size_t n_threads, size_t n_devices,
1716-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
1717-
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
1725+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1726+
if (n_devices == 0 || devices == nullptr) {
17181727
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
17191728
return;
17201729
}
17211730
std::vector<ggml_backend_t> backends;
1722-
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
1723-
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
17241731
printf("Starting RPC server v%d.%d.%d\n",
17251732
RPC_PROTO_MAJOR_VERSION,
17261733
RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17301737
printf("Devices:\n");
17311738
for (size_t i = 0; i < n_devices; i++) {
17321739
auto dev = devices[i];
1740+
size_t free, total;
1741+
ggml_backend_dev_memory(dev, &free, &total);
17331742
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1734-
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
1743+
total / 1024 / 1024, free / 1024 / 1024);
17351744
auto backend = ggml_backend_dev_init(dev, nullptr);
17361745
if (!backend) {
17371746
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17751784
}
17761785
printf("Accepted client connection\n");
17771786
fflush(stdout);
1778-
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
1787+
rpc_serve_client(backends, cache_dir, client_socket->fd);
17791788
printf("Client connection closed\n");
17801789
fflush(stdout);
17811790
}

0 commit comments

Comments
 (0)