Skip to content

Commit 075f1e0

Browse files
committed
rpc : report actual free memory
Start reporting the free memory on every device instead of using fixed values. Now llama-cli users can get a nice memory breakdown when using RPC devices.
1 parent 3e3cb19 commit 075f1e0

File tree

3 files changed

+42
-22
lines changed

3 files changed

+42
-22
lines changed

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint,
2222

2323
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
2424
size_t n_threads, size_t n_devices,
25-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
25+
ggml_backend_dev_t * devices, size_t * max_mem);
2626

2727
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
2828
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
920920

921921
class rpc_server {
922922
public:
923-
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
924-
: backends(std::move(backends)), cache_dir(cache_dir) {
923+
rpc_server(std::vector<ggml_backend_t> backends, std::vector<size_t> max_mem, const char * cache_dir)
924+
: backends(std::move(backends)), max_mem(std::move(max_mem)), cache_dir(cache_dir) {
925925
}
926926
~rpc_server();
927927

@@ -939,6 +939,7 @@ class rpc_server {
939939
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
940940
bool init_tensor(const rpc_msg_init_tensor_req & request);
941941
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942+
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
942943

943944
private:
944945
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -950,6 +951,7 @@ class rpc_server {
950951

951952

952953
std::vector<ggml_backend_t> backends;
954+
std::vector<size_t> max_mem;
953955
const char * cache_dir;
954956
std::unordered_set<ggml_backend_buffer_t> buffers;
955957
};
@@ -1458,15 +1460,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14581460
return true;
14591461
}
14601462

1463+
static void rpc_dev_memory(ggml_backend_dev_t dev, size_t max_mem, size_t * free, size_t * total) {
1464+
// cap total and free memory to the user-specified max memory
1465+
size_t free_dev, total_dev;
1466+
ggml_backend_dev_memory(dev, &free_dev, &total_dev);
1467+
total_dev = (total_dev > max_mem) ? max_mem : total_dev;
1468+
free_dev = (free_dev > max_mem) ? max_mem : free_dev;
1469+
*free = free_dev;
1470+
*total = total_dev;
1471+
}
1472+
1473+
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1474+
uint32_t dev_id = request.device;
1475+
if (dev_id >= backends.size()) {
1476+
return false;
1477+
}
1478+
size_t free, total;
1479+
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1480+
rpc_dev_memory(dev, max_mem[dev_id], &free, &total);
1481+
response.free_mem = free;
1482+
response.total_mem = total;
1483+
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1484+
return true;
1485+
}
1486+
14611487
rpc_server::~rpc_server() {
14621488
for (auto buffer : buffers) {
14631489
ggml_backend_buffer_free(buffer);
14641490
}
14651491
}
14661492

14671493
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1468-
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
1469-
rpc_server server(backends, cache_dir);
1494+
sockfd_t sockfd, const std::vector<size_t> & max_mem) {
1495+
rpc_server server(backends, max_mem, cache_dir);
14701496
uint8_t cmd;
14711497
if (!recv_data(sockfd, &cmd, 1)) {
14721498
return;
@@ -1689,15 +1715,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16891715
if (!recv_msg(sockfd, &request, sizeof(request))) {
16901716
return;
16911717
}
1692-
auto dev_id = request.device;
1693-
if (dev_id >= backends.size()) {
1718+
rpc_msg_get_device_memory_rsp response;
1719+
if (!server.get_device_memory(request, response)) {
16941720
return;
16951721
}
1696-
rpc_msg_get_device_memory_rsp response;
1697-
response.free_mem = free_mem[dev_id];
1698-
response.total_mem = total_mem[dev_id];
1699-
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
1700-
response.free_mem, response.total_mem);
17011722
if (!send_msg(sockfd, &response, sizeof(response))) {
17021723
return;
17031724
}
@@ -1713,14 +1734,13 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
17131734

17141735
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
17151736
size_t n_threads, size_t n_devices,
1716-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
1717-
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
1737+
ggml_backend_dev_t * devices, size_t * max_mem) {
1738+
if (n_devices == 0 || devices == nullptr || max_mem == nullptr) {
17181739
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
17191740
return;
17201741
}
17211742
std::vector<ggml_backend_t> backends;
1722-
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
1723-
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
1743+
std::vector<size_t> max_mem_vec(max_mem, max_mem + n_devices);
17241744
printf("Starting RPC server v%d.%d.%d\n",
17251745
RPC_PROTO_MAJOR_VERSION,
17261746
RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1750,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17301750
printf("Devices:\n");
17311751
for (size_t i = 0; i < n_devices; i++) {
17321752
auto dev = devices[i];
1753+
size_t free, total;
1754+
rpc_dev_memory(dev, max_mem[i], &free, &total);
17331755
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1734-
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
1756+
total / 1024 / 1024, free / 1024 / 1024);
17351757
auto backend = ggml_backend_dev_init(dev, nullptr);
17361758
if (!backend) {
17371759
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
@@ -1775,7 +1797,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17751797
}
17761798
printf("Accepted client connection\n");
17771799
fflush(stdout);
1778-
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
1800+
rpc_serve_client(backends, cache_dir, client_socket->fd, max_mem_vec);
17791801
printf("Client connection closed\n");
17801802
fflush(stdout);
17811803
}

tools/rpc/rpc-server.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,13 @@ int main(int argc, char * argv[]) {
293293
return 1;
294294
}
295295
std::string endpoint = params.host + ":" + std::to_string(params.port);
296-
std::vector<size_t> free_mem, total_mem;
296+
std::vector<size_t> total_mem;
297297
for (size_t i = 0; i < devices.size(); i++) {
298298
if (i < params.dev_mem.size()) {
299-
free_mem.push_back(params.dev_mem[i]);
300299
total_mem.push_back(params.dev_mem[i]);
301300
} else {
302301
size_t free, total;
303302
ggml_backend_dev_memory(devices[i], &free, &total);
304-
free_mem.push_back(free);
305303
total_mem.push_back(total);
306304
}
307305
}
@@ -329,6 +327,6 @@ int main(int argc, char * argv[]) {
329327
}
330328

331329
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
332-
devices.data(), free_mem.data(), total_mem.data());
330+
devices.data(), total_mem.data());
333331
return 0;
334332
}

0 commit comments

Comments
 (0)