Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/include/ggml-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint,

GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem);
ggml_backend_dev_t * devices, size_t * max_mem);

GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint);
Expand Down
56 changes: 39 additions & 17 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,

class rpc_server {
public:
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backends(std::move(backends)), cache_dir(cache_dir) {
rpc_server(std::vector<ggml_backend_t> backends, std::vector<size_t> max_mem, const char * cache_dir)
: backends(std::move(backends)), max_mem(std::move(max_mem)), cache_dir(cache_dir) {
}
~rpc_server();

Expand All @@ -939,6 +939,7 @@ class rpc_server {
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
Expand All @@ -950,6 +951,7 @@ class rpc_server {


std::vector<ggml_backend_t> backends;
std::vector<size_t> max_mem;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
};
Expand Down Expand Up @@ -1458,15 +1460,39 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return true;
}

// Query the memory figures of `dev` and report them through `free` and `total`,
// with each figure clamped so that it never exceeds `max_mem`.
//   dev     - device that is queried via ggml_backend_dev_memory()
//   max_mem - upper bound applied independently to both reported values
//   free    - out: free memory of the device, at most max_mem
//   total   - out: total memory of the device, at most max_mem
static void rpc_dev_memory(ggml_backend_dev_t dev, size_t max_mem, size_t * free, size_t * total) {
    size_t reported_free;
    size_t reported_total;
    ggml_backend_dev_memory(dev, &reported_free, &reported_total);

    // apply the cap to each figure separately
    if (reported_total > max_mem) {
        reported_total = max_mem;
    }
    if (reported_free > max_mem) {
        reported_free = max_mem;
    }

    *free  = reported_free;
    *total = reported_total;
}

// Fill `response` with the free and total memory of the device selected by
// `request.device`, both capped to the max_mem entry stored for that device.
// Returns false (leaving `response` untouched) when the device index does not
// refer to a known backend or has no matching max_mem entry; returns true
// otherwise.
bool rpc_server::get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response) {
    uint32_t dev_id = request.device;
    // dev_id is used to index both `backends` and `max_mem`, so it has to be
    // validated against both containers; the request value is supplied by the
    // caller and must not be trusted to be in range
    if (dev_id >= backends.size() || dev_id >= max_mem.size()) {
        return false;
    }
    size_t free, total;
    ggml_backend_dev_t dev = ggml_backend_get_device(backends[dev_id]);
    rpc_dev_memory(dev, max_mem[dev_id], &free, &total);
    response.free_mem = free;
    response.total_mem = total;
    LOG_DBG("[%s] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", __func__, dev_id, response.free_mem, response.total_mem);
    return true;
}

// Destructor: frees every buffer that is still recorded in `buffers`.
rpc_server::~rpc_server() {
    for (auto it = buffers.begin(); it != buffers.end(); ++it) {
        ggml_backend_buffer_free(*it);
    }
}

static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir,
sockfd_t sockfd, const std::vector<size_t> & free_mem, const std::vector<size_t> & total_mem) {
rpc_server server(backends, cache_dir);
sockfd_t sockfd, const std::vector<size_t> & max_mem) {
rpc_server server(backends, max_mem, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
Expand Down Expand Up @@ -1689,15 +1715,10 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
auto dev_id = request.device;
if (dev_id >= backends.size()) {
rpc_msg_get_device_memory_rsp response;
if (!server.get_device_memory(request, response)) {
return;
}
rpc_msg_get_device_memory_rsp response;
response.free_mem = free_mem[dev_id];
response.total_mem = total_mem[dev_id];
LOG_DBG("[get_device_mem] device: %u, free_mem: %" PRIu64 ", total_mem: %" PRIu64 "\n", dev_id,
response.free_mem, response.total_mem);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
Expand All @@ -1713,14 +1734,13 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const

void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir,
size_t n_threads, size_t n_devices,
ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem) {
if (n_devices == 0 || devices == nullptr || free_mem == nullptr || total_mem == nullptr) {
ggml_backend_dev_t * devices, size_t * max_mem) {
if (n_devices == 0 || devices == nullptr || max_mem == nullptr) {
fprintf(stderr, "Invalid arguments to ggml_backend_rpc_start_server\n");
return;
}
std::vector<ggml_backend_t> backends;
std::vector<size_t> free_mem_vec(free_mem, free_mem + n_devices);
std::vector<size_t> total_mem_vec(total_mem, total_mem + n_devices);
std::vector<size_t> max_mem_vec(max_mem, max_mem + n_devices);
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
Expand All @@ -1730,8 +1750,10 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
printf("Devices:\n");
for (size_t i = 0; i < n_devices; i++) {
auto dev = devices[i];
size_t free, total;
rpc_dev_memory(dev, max_mem[i], &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
total_mem[i] / 1024 / 1024, free_mem[i] / 1024 / 1024);
total / 1024 / 1024, free / 1024 / 1024);
auto backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
fprintf(stderr, "Failed to create backend for device %s\n", dev->iface.get_name(dev));
Expand Down Expand Up @@ -1775,7 +1797,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir
}
printf("Accepted client connection\n");
fflush(stdout);
rpc_serve_client(backends, cache_dir, client_socket->fd, free_mem_vec, total_mem_vec);
rpc_serve_client(backends, cache_dir, client_socket->fd, max_mem_vec);
printf("Client connection closed\n");
fflush(stdout);
}
Expand Down
6 changes: 2 additions & 4 deletions tools/rpc/rpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,13 @@ int main(int argc, char * argv[]) {
return 1;
}
std::string endpoint = params.host + ":" + std::to_string(params.port);
std::vector<size_t> free_mem, total_mem;
std::vector<size_t> total_mem;
for (size_t i = 0; i < devices.size(); i++) {
if (i < params.dev_mem.size()) {
free_mem.push_back(params.dev_mem[i]);
total_mem.push_back(params.dev_mem[i]);
} else {
size_t free, total;
ggml_backend_dev_memory(devices[i], &free, &total);
free_mem.push_back(free);
total_mem.push_back(total);
}
}
Expand Down Expand Up @@ -329,6 +327,6 @@ int main(int argc, char * argv[]) {
}

start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
devices.data(), free_mem.data(), total_mem.data());
devices.data(), total_mem.data());
return 0;
}
Loading