Skip to content

Commit 41386cf

Browse files
authored
rpc : report actual free memory (ggml-org#16616)
* rpc : report actual free memory. Start reporting the free memory on every device instead of using fixed values. Now llama-cli users can get a nice memory breakdown when using RPC devices.
* drop --mem in rpc-server
1 parent 3d4e86b commit 41386cf

File tree

3 files changed

+26
-50
lines changed

3 files changed

+26
-50
lines changed

ggml/include/ggml-rpc.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const c
2121
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total);
2222

2323
GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
24-
size_t n_threads, size_t n_devices,
25-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
24+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices);
2625

2726
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
2827
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ class rpc_server {
939939
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
940940
bool init_tensor(const rpc_msg_init_tensor_req & request);
941941
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
942+
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
942943

943944
private:
944945
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -1458,14 +1459,28 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
14581459
return true;
14591460
}
14601461

1462+
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
1463+
uint32_t dev_id = request.device;
1464+
if (dev_id >= backends.size()) {
1465+
return false;
1466+
}
1467+
size_t free, total;
1468+
ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
1469+
ggml_backend_dev_memory(dev, &free, &total);
1470+
response.free_mem = free;
1471+
response.total_mem = total;
1472+
LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
1473+
return true;
1474+
}
1475+
14611476
rpc_server::~rpc_server() {
14621477
for (auto buffer : buffers) {
14631478
ggml_backend_buffer_free(buffer);
14641479
}
14651480
}
14661481

14671482
static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
1468-
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
1483+
sockfd_t sockfd) {
14691484
rpc_server server(backends, cache_dir);
14701485
uint8_t cmd;
14711486
if (!recv_data(sockfd, &cmd, 1)) {
@@ -1689,15 +1704,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
16891704
if (!recv_msg(sockfd, &request, sizeof(request))) {
16901705
return;
16911706
}
1692-
auto dev_id = request.device;
1693-
if (dev_id >= backends.size()) {
1707+
rpc_msg_get_device_memory_rsp response;
1708+
if (!server.get_device_memory(request, response)) {
16941709
return;
16951710
}
1696-
rpc_msg_get_device_memory_rsp response;
1697-
response.free_mem = free_mem[dev_id];
1698-
response.total_mem = total_mem[dev_id];
1699-
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
1700-
response.free_mem, response.total_mem);
17011711
if (!send_msg(sockfd, &response, sizeof(response))) {
17021712
return;
17031713
}
@@ -1712,15 +1722,12 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
17121722
}
17131723

17141724
void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
1715-
size_t n_threads, size_t n_devices,
1716-
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
1717-
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
1725+
size_t n_threads, size_t n_devices, ggml_backend_dev_t * devices) {
1726+
if (n_devices == 0 || devices == nullptr) {
17181727
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
17191728
return;
17201729
}
17211730
std::vector<ggml_backend_t> backends;
1722-
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
1723-
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
17241731
printf("Starting RPC server v%d.%d.%d\n",
17251732
RPC_PROTO_MAJOR_VERSION,
17261733
RPC_PROTO_MINOR_VERSION,
@@ -1730,8 +1737,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17301737
printf("Devices:\n");
17311738
for (size_t i = 0; i < n_devices; i++) {
17321739
auto dev = devices[i];
1740+
size_t free, total;
1741+
ggml_backend_dev_memory(dev, &free, &total);
17331742
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
1734-
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
1743+
total / 1024 / 1024, free / 1024 / 1024);
17351744
auto backend = ggml_backend_dev_init(dev, nullptr);
17361745
if (!backend) {
17371746
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
@@ -1775,7 +1784,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
17751784
}
17761785
printf("Accepted client connection\n");
17771786
fflush(stdout);
1778-
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
1787+
rpc_serve_client(backends, cache_dir, client_socket->fd);
17791788
printf("Client connection closed\n");
17801789
fflush(stdout);
17811790
}

tools/rpc/rpc-server.cpp

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ struct rpc_server_params {
137137
bool use_cache = false;
138138
int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
139139
std::vector<std::string> devices;
140-
std::vector<size_t> dev_mem;
141140
};
142141

143142
static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
@@ -148,7 +147,6 @@ static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
148147
fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
149148
fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
150149
fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
151-
fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
152150
fprintf(stderr, " -c, --cache enable local file cache\n");
153151
fprintf(stderr, "\n");
154152
}
@@ -197,23 +195,6 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
197195
}
198196
} else if (arg == "-c" || arg == "--cache") {
199197
params.use_cache = true;
200-
} else if (arg == "-m" || arg == "--mem") {
201-
if (++i >= argc) {
202-
return false;
203-
}
204-
const std::regex regex{ R"([,/]+)" };
205-
std::string mem_str = argv[i];
206-
std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
207-
std::sregex_token_iterator end;
208-
for ( ; iter != end; ++iter) {
209-
try {
210-
size_t mem = std::stoul(*iter) * 1024 * 1024;
211-
params.dev_mem.push_back(mem);
212-
} catch (const std::exception & ) {
213-
fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
214-
return false;
215-
}
216-
}
217198
} else if (arg == "-h" || arg == "--help") {
218199
print_usage(argc, argv, params);
219200
exit(0);
@@ -293,18 +274,6 @@ int main(int argc, char * argv[]) {
293274
return 1;
294275
}
295276
std::string endpoint = params.host + ":" + std::to_string(params.port);
296-
std::vector<size_t> free_mem, total_mem;
297-
for (size_t i = 0; i < devices.size(); i++) {
298-
if (i < params.dev_mem.size()) {
299-
free_mem.push_back(params.dev_mem[i]);
300-
total_mem.push_back(params.dev_mem[i]);
301-
} else {
302-
size_t free, total;
303-
ggml_backend_dev_memory(devices[i], &free, &total);
304-
free_mem.push_back(free);
305-
total_mem.push_back(total);
306-
}
307-
}
308277
const char * cache_dir = nullptr;
309278
std::string cache_dir_str;
310279
if (params.use_cache) {
@@ -328,7 +297,6 @@ int main(int argc, char * argv[]) {
328297
return 1;
329298
}
330299

331-
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
332-
devices.data(), free_mem.data(), total_mem.data());
300+
start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
333301
return 0;
334302
}

0 commit comments

Comments (0)