Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/rpc/rpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,10 @@ int main(int argc, char * argv[]) {
}
cache_dir = cache_dir_str.c_str();
}
printf("Starting RPC server\n");
printf("Starting RPC server v%d.%d.%d\n",
RPC_PROTO_MAJOR_VERSION,
RPC_PROTO_MINOR_VERSION,
RPC_PROTO_PATCH_VERSION);
printf(" endpoint : %s\n", endpoint.c_str());
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
Expand Down
3 changes: 3 additions & 0 deletions ggml/include/ggml-rpc.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
extern "C" {
#endif

#define RPC_PROTO_MAJOR_VERSION 1
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16

// backend API
Expand Down
54 changes: 53 additions & 1 deletion ggml/src/ggml-rpc/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,19 @@ enum rpc_cmd {
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_COUNT,
};

// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;

struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
};

struct rpc_msg_get_alloc_size_req {
rpc_tensor tensor;
};
Expand Down Expand Up @@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

// RPC client-side implementation

static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
return true;
}

static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
Expand Down Expand Up @@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
if (!check_server_version(sock)) {
return nullptr;
}
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
Expand Down Expand Up @@ -818,6 +842,7 @@ class rpc_server {
}
~rpc_server();

void hello(rpc_msg_hello_rsp & response);
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(rpc_msg_get_alignment_rsp & response);
void get_max_size(rpc_msg_get_max_size_rsp & response);
Expand Down Expand Up @@ -846,6 +871,13 @@ class rpc_server {
std::unordered_set<ggml_backend_buffer_t> buffers;
};

void rpc_server::hello(rpc_msg_hello_rsp & response) {
response.major = RPC_PROTO_MAJOR_VERSION;
response.minor = RPC_PROTO_MINOR_VERSION;
response.patch = RPC_PROTO_PATCH_VERSION;
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
}

bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
ggml_backend_buffer_type_t buft;
struct ggml_init_params params {
Expand Down Expand Up @@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
fprintf(stderr, "Expected HELLO command, update client\n");
return;
}
if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_hello_rsp response;
server.hello(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
Expand All @@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
switch (cmd) {
case RPC_CMD_HELLO: {
// HELLO command is handled above
return;
}
case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
Expand Down
Loading