Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions ggml/src/ggml-rpc/ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@
#include <filesystem>
#include <algorithm>

static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG");

#define LOG_DBG(...) \
do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0)


namespace fs = std::filesystem;

static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
Expand All @@ -47,7 +53,7 @@ struct socket_t {
sockfd_t fd;
socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() {
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
LOG_DBG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
closesocket(this->fd);
#else
Expand Down Expand Up @@ -265,14 +271,14 @@ static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
return nullptr;
}
if (!set_no_delay(sockfd)) {
fprintf(stderr, "Failed to set TCP_NODELAY\n");
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
return nullptr;
}
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
struct hostent * server = gethostbyname(host);
if (server == NULL) {
fprintf(stderr, "Cannot resolve host '%s'\n", host);
GGML_LOG_ERROR("Cannot resolve host '%s'\n", host);
return nullptr;
}
memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
Expand All @@ -289,7 +295,7 @@ static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
return nullptr;
}
if (!set_no_delay(client_socket_fd)) {
fprintf(stderr, "Failed to set TCP_NODELAY\n");
GGML_LOG_ERROR("Failed to set TCP_NODELAY\n");
return nullptr;
}
return client_socket;
Expand All @@ -302,11 +308,11 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
return nullptr;
}
if (!set_reuse_addr(sockfd)) {
fprintf(stderr, "Failed to set SO_REUSEADDR\n");
GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n");
return nullptr;
}
if (inet_addr(host) == INADDR_NONE) {
fprintf(stderr, "Invalid host address: %s\n", host);
GGML_LOG_ERROR("Invalid host address: %s\n", host);
return nullptr;
}
struct sockaddr_in serv_addr;
Expand Down Expand Up @@ -349,7 +355,7 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return false;
}
if (n == 0) {
GGML_LOG_ERROR("recv returned 0 (peer closed?)\n");
LOG_DBG("recv returned 0 (peer closed?)\n");
return false;
}
bytes_recv += (size_t)n;
Expand Down Expand Up @@ -383,7 +389,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
try {
input.resize(size);
} catch (const std::bad_alloc & e) {
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size);
return false;
}
return recv_data(sockfd, input.data(), size);
Expand Down Expand Up @@ -443,11 +449,11 @@ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
RPC_STATUS_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
return true;
}
Expand Down Expand Up @@ -488,7 +494,7 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (!check_server_version(sock)) {
return nullptr;
}
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
}
Expand Down Expand Up @@ -809,7 +815,7 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
}
auto sock = get_socket(endpoint);
if (sock == nullptr) {
fprintf(stderr, "Failed to connect to %s\n", endpoint);
GGML_LOG_ERROR("Failed to connect to %s\n", endpoint);
return nullptr;
}
size_t alignment = get_alignment(sock);
Expand Down Expand Up @@ -909,7 +915,7 @@ void rpc_server::hello(rpc_msg_hello_rsp & response) {
response.major = RPC_PROTO_MAJOR_VERSION;
response.minor = RPC_PROTO_MINOR_VERSION;
response.patch = RPC_PROTO_PATCH_VERSION;
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
LOG_DBG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
}

bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
Expand All @@ -929,15 +935,15 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
return false;
}

LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
if (tensor->buffer == nullptr) {
//No buffer allocated.
buft = ggml_backend_get_default_buffer_type(backend);
} else {
buft = tensor->buffer->buft;
}

response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
response.alloc_size = ggml_backend_buft_get_alloc_size(buft, tensor);

return true;
}
Expand All @@ -950,29 +956,29 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
if (buffer != nullptr) {
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
response.remote_size = buffer->size;
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
LOG_DBG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
buffers.insert(buffer);
} else {
GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
LOG_DBG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
}
}

void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
size_t alignment = ggml_backend_buft_get_alignment(buft);
GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
LOG_DBG("[%s] alignment: %lu\n", __func__, alignment);
response.alignment = alignment;
}

void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
size_t max_size = ggml_backend_buft_get_max_size(buft);
GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
LOG_DBG("[%s] max_size: %lu\n", __func__, max_size);
response.max_size = max_size;
}

bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
Expand All @@ -984,7 +990,7 @@ bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rp
}

bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
LOG_DBG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
Expand All @@ -996,7 +1002,7 @@ bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
}

bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
LOG_DBG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
if (buffers.find(buffer) == buffers.end()) {
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
Expand Down Expand Up @@ -1073,7 +1079,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);

// sanitize tensor->data
{
Expand All @@ -1096,7 +1102,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
fs::path cache_file = fs::path(cache_dir) / hash_str;
std::ofstream ofs(cache_file, std::ios::binary);
ofs.write((const char *)data, size);
printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str());
}
ggml_backend_tensor_set(tensor, data, offset, size);
return true;
Expand Down Expand Up @@ -1142,8 +1148,8 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
__func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);

// sanitize tensor->data
{
Expand Down Expand Up @@ -1177,7 +1183,7 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
return false;
}

LOG_DBG("[%s] buffer: %p, data: %p\n", __func__, (void*)tensor->buffer, tensor->data);
// Call the backend's buffer_init_tensor function
ggml_backend_buffer_t buffer = tensor->buffer;
if (buffer && buffer->iface.init_tensor) {
Expand Down Expand Up @@ -1210,7 +1216,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
return false;
}
GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
LOG_DBG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);

// sanitize tensor->data
{
Expand Down Expand Up @@ -1254,7 +1260,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);

if (dst_data + src_size > dst_base + dst_buf_sz) {
GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
GGML_LOG_ERROR("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
" write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
" buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
__func__,
Expand All @@ -1265,8 +1271,8 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
return false;
}

GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
__func__, (void*) src->buffer, (void*) dst->buffer);
LOG_DBG("[%s] src->buffer: %p, dst->buffer: %p\n",
__func__, (void*) src->buffer, (void*) dst->buffer);

response.result = ggml_backend_buffer_copy_tensor(src, dst);
return true;
Expand Down Expand Up @@ -1342,7 +1348,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
return false;
}
const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
LOG_DBG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);

size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

Expand Down Expand Up @@ -1394,7 +1400,7 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
fprintf(stderr, "Expected HELLO command, update client\n");
GGML_LOG_ERROR("Expected HELLO command, update client\n");
return;
}
if (!recv_msg(sockfd, nullptr, 0)) {
Expand All @@ -1411,7 +1417,7 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
}
if (cmd >= RPC_CMD_COUNT) {
// fail fast if the command is invalid
fprintf(stderr, "Unknown command: %d\n", cmd);
GGML_LOG_ERROR("Unknown command: %d\n", cmd);
break;
}
switch (cmd) {
Expand Down Expand Up @@ -1599,7 +1605,7 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
default: {
fprintf(stderr, "Unknown command: %d\n", cmd);
GGML_LOG_ERROR("Unknown command: %d\n", cmd);
return;
}
}
Expand Down
Loading