ggml/include/ggml-rpc.h: 2 changes (1 addition, 1 deletion)
@@ -8,7 +8,7 @@ extern "C" {
#endif

#define RPC_PROTO_MAJOR_VERSION 3
-#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_MINOR_VERSION 5
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
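The minor-version bump advertises the two new graph-caching commands without changing the wire format of existing messages. A minimal sketch of the gate a client could apply after the HELLO handshake (the helper name is ours, not part of the patch):

```cpp
// Hypothetical helper: only use the new commands when the server's HELLO
// reply reports a protocol that includes them. The major version guards
// wire-format compatibility; minor >= 5 implies the server understands
// RPC_CMD_GRAPH_COMPUTE_AND_STORE and RPC_CMD_GRAPH_RECOMPUTE.
static bool server_supports_graph_cache(const rpc_msg_hello_rsp & hello) {
    return hello.major == RPC_PROTO_MAJOR_VERSION && hello.minor >= 5;
}
```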

ggml/src/ggml-rpc/ggml-rpc.cpp: 149 changes (135 additions, 14 deletions)
@@ -106,13 +106,17 @@ enum rpc_cmd {
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_COMPUTE_AND_STORE,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};

static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must always be 14");

// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
const int MAX_STORED_GRAPHS = 64;
const uint64_t INVALID_GRAPH_ID = UINT64_MAX;

struct rpc_msg_hello_rsp {
uint8_t major;
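MAX_STORED_GRAPHS sizes a ring buffer of cached graphs on the server, and INVALID_GRAPH_ID tags tensors that have never been associated with a stored graph. The liveness rule both sides rely on, as a sketch (the helper is ours):

```cpp
// Graph `g` occupies slot g % MAX_STORED_GRAPHS and is evicted once 64
// newer graphs have been stored. Example: graph 0 survives graphs 1..63,
// then graph 64 lands in slot 0; at that point 0 + 64 > 64 is false and a
// well-behaved client falls back to re-sending the full graph.
static bool graph_still_cached(uint64_t g, uint64_t newest_graph_id) {
    return g != INVALID_GRAPH_ID && g + MAX_STORED_GRAPHS > newest_graph_id;
}
```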
@@ -217,6 +221,20 @@ struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};

struct rpc_msg_graph_compute_and_store_rsp {
uint8_t result;
uint64_t graph_id;
};

struct rpc_msg_graph_recompute_req {
uint64_t graph_id;
};

struct rpc_msg_graph_recompute_rsp {
uint8_t result;
};

#pragma pack(pop)

// RPC data structures
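The three new messages live inside the file's packed-struct region (the #pragma pack(pop) above closes it), so their wire sizes are exact field sums. A sanity sketch, assuming the region is opened with #pragma pack(push, 1) like the rest of the RPC messages:

```cpp
// With 1-byte packing there is no padding between fields.
static_assert(sizeof(rpc_msg_graph_compute_and_store_rsp) == 9, "1 + 8 bytes");
static_assert(sizeof(rpc_msg_graph_recompute_req)         == 8, "8 bytes");
static_assert(sizeof(rpc_msg_graph_recompute_rsp)         == 1, "1 byte");
```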
@@ -238,6 +256,7 @@ struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
uint64_t curr_graph_id;
};

struct ggml_backend_rpc_buffer_context {
@@ -592,6 +611,8 @@ static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
RPC_STATUS_ASSERT(status);
}
// HACK: use the extra field for storing the graph ID
tensor->extra = reinterpret_cast<void*>(INVALID_GRAPH_ID);
return GGML_STATUS_SUCCESS;
}
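ggml_tensor has no spare integer field, so the patch smuggles the graph ID through the extra pointer. A sketch of the round-trip and its main caveat (helper names are ours):

```cpp
// The void * field carries the 64-bit ID by value, not by address. This
// assumes sizeof(void *) == 8; a 32-bit build would truncate the ID and
// could not even represent INVALID_GRAPH_ID (UINT64_MAX).
static void tag_graph_id(ggml_tensor * t, uint64_t id) {
    t->extra = reinterpret_cast<void *>(id);
}
static uint64_t tagged_graph_id(const ggml_tensor * t) {
    return reinterpret_cast<uint64_t>(t->extra);
}
```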

@@ -815,13 +836,30 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve

static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-std::vector<uint8_t> input;
-serialize_graph(rpc_ctx->device, cgraph, input);
-rpc_msg_graph_compute_rsp response;
-auto sock = get_socket(rpc_ctx->endpoint);
-bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-RPC_STATUS_ASSERT(status);
-return (enum ggml_status)response.result;

GGML_ASSERT(cgraph->n_nodes > 0);
// HACK: we store the graph ID in the first node's extra field
uint64_t stored_graph_id = reinterpret_cast<uint64_t>(cgraph->nodes[0]->extra);
bool reuse_graph = stored_graph_id != INVALID_GRAPH_ID && (stored_graph_id + MAX_STORED_GRAPHS > rpc_ctx->curr_graph_id);
if (reuse_graph) {
rpc_msg_graph_recompute_req request;
request.graph_id = stored_graph_id;
rpc_msg_graph_recompute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.result;
} else {
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_and_store_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE_AND_STORE, input.data(), input.size(), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
rpc_ctx->curr_graph_id = response.graph_id;
cgraph->nodes[0]->extra = reinterpret_cast<void*>(response.graph_id);
return (enum ggml_status)response.result;
}
}
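The intended effect shows up in steady-state decoding, where the scheduler submits a structurally identical graph every step. A sketch of the expected traffic, assuming the same graph nodes (and hence the nodes[0]->extra tag) survive between steps:

```cpp
// step 0 : RPC_CMD_GRAPH_COMPUTE_AND_STORE   -- full serialized graph
//          reply: rpc_msg_graph_compute_and_store_rsp { result, graph_id }
// step 1+: RPC_CMD_GRAPH_RECOMPUTE           -- fixed 8-byte request
//          reply: rpc_msg_graph_recompute_rsp { result }
rpc_msg_graph_recompute_req req = { /*graph_id=*/ stored_graph_id };

// Per-step client work drops from serializing every node and tensor to
// filling one small struct, as long as fewer than MAX_STORED_GRAPHS other
// graphs were stored on that server in the meantime.
```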

static ggml_backend_i ggml_backend_rpc_interface = {
@@ -878,9 +916,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u
ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]";
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
-/* .endpoint = */ endpoint,
-/* .device = */ device,
-/* .name = */ dev_name
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name,
/* .curr_graph_id = */ 0,
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@@ -921,7 +960,8 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
class rpc_server {
public:
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
-: backends(std::move(backends)), cache_dir(cache_dir) {
: backends(std::move(backends)), cache_dir(cache_dir), curr_graph_id(0) {
stored_graphs.resize(MAX_STORED_GRAPHS);
}
~rpc_server();

@@ -937,22 +977,34 @@ class rpc_server {
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response);
bool graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);

struct stored_graph {
uint32_t device;
ggml_context_ptr ctx_ptr;
ggml_cgraph * graph;
};

private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
ggml_tensor * create_node(uint64_t id,
struct ggml_context * ctx,
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
bool store_graph(const std::vector<uint8_t> & input, stored_graph & sg);


std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
uint64_t curr_graph_id;
// ring buffer for storing graphs
std::vector<stored_graph> stored_graphs;
};

void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1446,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result;
}

-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
bool rpc_server::store_graph(const std::vector<uint8_t> & input, stored_graph & sg) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
@@ -1422,7 +1474,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false;
}
const rpc_tensor * tensors = (const rpc_tensor *)src;
-LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);

size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

@@ -1454,6 +1505,47 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false;
}
}
sg.ctx_ptr.swap(ctx_ptr);
sg.graph = graph;
sg.device = device;
return true;
}
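store_graph keeps graph_compute's deserialization path, and the format comment above fixes the payload layout. The size math it implies, as a quick sketch (the helper is ours):

```cpp
// | device (4) | n_nodes (4) | nodes (n_nodes * 8) | n_tensors (4) | tensors |
static size_t graph_payload_size(uint32_t n_nodes, uint32_t n_tensors) {
    return 2 * sizeof(uint32_t)                   // device + n_nodes
         + (size_t) n_nodes   * sizeof(uint64_t)  // node IDs
         + sizeof(uint32_t)                       // n_tensors
         + (size_t) n_tensors * sizeof(rpc_tensor);
}
```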

bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
stored_graph sg;
if (!store_graph(input, sg)) {
return false;
}
uint32_t device = sg.device;
LOG_DBG("[%s] device: %u, input: %zu bytes\n", __func__, device, input.size());
ggml_status status = ggml_backend_graph_compute(backends[device], sg.graph);
response.result = status;
return true;
}

bool rpc_server::graph_compute_and_store(const std::vector<uint8_t> & input, rpc_msg_graph_compute_and_store_rsp & response) {
int graph_slot = curr_graph_id % MAX_STORED_GRAPHS;
if (!store_graph(input, stored_graphs[graph_slot])) {
return false;
}
ggml_cgraph * graph = stored_graphs[graph_slot].graph;
uint32_t device = stored_graphs[graph_slot].device;
LOG_DBG("[%s] device: %u, input: %zu bytes, graph_id: %" PRIu64 "\n", __func__, device, input.size(), curr_graph_id);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
response.graph_id = curr_graph_id;
curr_graph_id++;
return true;
}

bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request, rpc_msg_graph_recompute_rsp & response) {
int graph_slot = request.graph_id % MAX_STORED_GRAPHS;
if (stored_graphs[graph_slot].graph == nullptr) {
return false;
}
ggml_cgraph * graph = stored_graphs[graph_slot].graph;
uint32_t device = stored_graphs[graph_slot].device;
LOG_DBG("[%s] device: %u, graph_id: %" PRIu64 "\n", __func__, device, request.graph_id);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
return true;
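One thing to note: graph_recompute only checks that the slot is occupied, not that it still holds request.graph_id. A single client that honors the reuse window never trips over this, but if several clients (or successive sessions) share one server, a slot can be recycled while an older handle still passes the client-side test, and the server would then recompute the wrong graph. A stricter check is a small extension, sketched here assuming stored_graph gains a graph_id field (hypothetical, not in this patch):

```cpp
// Hypothetical hardening: remember which ID owns each slot and reject
// stale handles instead of running whichever graph lives there now.
struct stored_graph {
    uint32_t         device;
    ggml_context_ptr ctx_ptr;
    ggml_cgraph *    graph    = nullptr;
    uint64_t         graph_id = INVALID_GRAPH_ID;   // new field
};

bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request,
                                 rpc_msg_graph_recompute_rsp & response) {
    stored_graph & sg = stored_graphs[request.graph_id % MAX_STORED_GRAPHS];
    if (sg.graph == nullptr || sg.graph_id != request.graph_id) {
        return false;   // never stored, or already evicted
    }
    response.result = ggml_backend_graph_compute(backends[sg.device], sg.graph);
    return true;
}
```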
@@ -1699,6 +1791,35 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
}
break;
}

case RPC_CMD_GRAPH_COMPUTE_AND_STORE: {
std::vector<uint8_t> input;
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_graph_compute_and_store_rsp response;
if (!server.graph_compute_and_store(input, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
case RPC_CMD_GRAPH_RECOMPUTE: {
rpc_msg_graph_recompute_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
rpc_msg_graph_recompute_rsp response;
if (!server.graph_recompute(request, response)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
break;
}
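In the dispatch loop a handler returning false tears down the connection, so a recompute miss currently ends the session and the client dies on RPC_STATUS_ASSERT. A gentler client-side policy is possible; a sketch, where send_full_graph is a hypothetical wrapper around the COMPUTE_AND_STORE path above:

```cpp
// Hypothetical fallback (not what this patch does): treat a failed
// RECOMPUTE round-trip as a cache miss and re-send the full graph.
if (!send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request),
                  &response, sizeof(response))) {
    return send_full_graph(rpc_ctx, cgraph);   // hypothetical helper
}
```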
case RPC_CMD_GET_DEVICE_MEMORY: {
rpc_msg_get_device_memory_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {