
Commit b1f78e2

rpc: check op support
Signed-off-by: thxCode <[email protected]>
1 parent: 2eea03d

File tree: 1 file changed (+102, -4)

ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -86,6 +86,7 @@ enum rpc_cmd {
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_SUPPORT_OP,
     RPC_CMD_COUNT,
 };
 
@@ -158,6 +159,11 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
+
+struct rpc_msg_support_op_rsp {
+    uint8_t result;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
@@ -438,7 +444,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
-    if (tensor->buffer) {
+    if (tensor->buffer && tensor->buffer->context) {
         ggml_backend_buffer_t buffer = tensor->buffer;
         ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
         result.buffer = ctx->remote_ptr;
@@ -767,6 +773,31 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
     get_device_memory(sock, free, total);
 }
 
+static bool ggml_backend_rpc_support_op(const char * endpoint, const ggml_tensor * tensor) {
+    std::vector<uint8_t> input;
+    {
+        std::vector<rpc_tensor> tensors;
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (tensor->src[i] == nullptr) {
+                break;
+            }
+            tensors.push_back(serialize_tensor(tensor->src[i]));
+        }
+        tensors.push_back(serialize_tensor(tensor));
+        // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+        uint32_t n_tensors = tensors.size();
+        int input_size = sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
+        input.resize(input_size, 0);
+        memcpy(input.data(), &n_tensors, sizeof(n_tensors));
+        memcpy(input.data() + sizeof(n_tensors), tensors.data(), n_tensors * sizeof(rpc_tensor));
+    }
+    rpc_msg_support_op_rsp response;
+    auto sock = get_socket(endpoint);
+    bool status = send_rpc_cmd(sock, RPC_CMD_SUPPORT_OP, input.data(), input.size(), &response, sizeof(response));
+    GGML_ASSERT(status);
+    return response.result;
+}
+
 // RPC server-side implementation
 
 class rpc_server {
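
Note on how this client-side path is reached: ggml_backend_rpc_support_op() is file-static, so callers go through the device interface wired up further below. A minimal sketch, assuming the public registry API from ggml-rpc.h and ggml-backend.h; the endpoint string and tensor shapes here are made up for illustration:

    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-rpc.h"

    // Probe whether the remote backend can run a 64x64 F32 matrix multiply.
    static bool probe_remote_mul_mat(const char * endpoint /* e.g. "127.0.0.1:50052" */) {
        ggml_backend_dev_t dev = ggml_backend_rpc_add_device(endpoint);

        // no_alloc context: only tensor metadata is created, no data buffers
        struct ggml_init_params params = {
            /*.mem_size   =*/ 3 * ggml_tensor_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(params);
        ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
        ggml_tensor * c = ggml_mul_mat(ctx, a, b);

        // dispatches to ggml_backend_rpc_device_supports_op(), which issues RPC_CMD_SUPPORT_OP
        bool ok = ggml_backend_dev_supports_op(dev, c);
        ggml_free(ctx);
        return ok;
    }
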
@@ -786,6 +817,7 @@ class rpc_server {
     bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
+    bool support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -829,6 +861,42 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     return true;
 }
 
+bool rpc_server::support_op(const std::vector<uint8_t> & input, rpc_msg_support_op_rsp & response) {
+    // serialization format: | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
+    if (input.size() < sizeof(uint32_t)) {
+        GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
+        return false;
+    }
+    uint32_t n_tensors;
+    memcpy(&n_tensors, input.data(), sizeof(n_tensors));
+    if (input.size() < sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor)) {
+        GGML_LOG_ERROR("[%s] invalid input size\n", __func__);
+        return false;
+    }
+    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(uint32_t));
+    GGML_PRINT_DEBUG("[%s] n_tensors: %u\n", __func__, n_tensors);
+
+    size_t buf_size = ggml_tensor_overhead()*n_tensors;
+    struct ggml_init_params params {
+        /*.mem_size   =*/ buf_size,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+    struct ggml_context * ctx = ggml_init(params);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &tensors[n_tensors-1]);
+    for (uint32_t i = 0; i < n_tensors-1; i++) {
+        ggml_tensor * src = deserialize_tensor(ctx, &tensors[i]);
+        tensor->src[i] = src;
+    }
+    response.result = true;
+    if (backend->device->iface.supports_op) {
+        response.result = backend->device->iface.supports_op(backend->device, tensor);
+    }
+    ggml_free(ctx);
+
+    return true;
+}
+
 void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
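
A side note on the context sizing above: with no_alloc = true a ggml context holds only tensor metadata, never tensor data, so ggml_tensor_overhead()*n_tensors bounds the whole allocation. A standalone sketch of the same pattern:

    #include "ggml.h"

    // Create a metadata-only context big enough for n_tensors deserialized tensors.
    static struct ggml_context * make_metadata_ctx(uint32_t n_tensors) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,
            /*.mem_buffer =*/ NULL,   // let the context allocate its own pool
            /*.no_alloc   =*/ true,   // metadata only: tensor->data is never allocated
        };
        return ggml_init(params);     // caller releases it with ggml_free()
    }
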
@@ -1326,6 +1394,20 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
                 }
                 break;
             }
+            case RPC_CMD_SUPPORT_OP: {
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_support_op_rsp response;
+                if (!server.support_op(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
+                break;
+            }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
                 return;
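
Unlike most commands in this switch, RPC_CMD_SUPPORT_OP carries a variable-length request, hence the recv_msg() overload that fills a std::vector instead of a fixed-size struct. A toy round-trip of the | n_tensors | tensors | framing, with a stand-in Item type in place of the packed rpc_tensor:

    #include <cstdint>
    #include <cstring>
    #include <vector>

    struct Item { uint64_t id; uint32_t type; };  // stand-in for rpc_tensor

    // Pack: | n (4 bytes) | n * sizeof(Item) payload |
    static std::vector<uint8_t> pack(const std::vector<Item> & items) {
        const uint32_t n = (uint32_t)items.size();
        std::vector<uint8_t> buf(sizeof(n) + n * sizeof(Item));
        memcpy(buf.data(), &n, sizeof(n));
        if (n > 0) {
            memcpy(buf.data() + sizeof(n), items.data(), n * sizeof(Item));
        }
        return buf;
    }

    // Unpack with the same two bounds checks rpc_server::support_op performs.
    static bool unpack(const std::vector<uint8_t> & buf, std::vector<Item> & out) {
        uint32_t n;
        if (buf.size() < sizeof(n)) {
            return false;  // too short to even hold the count
        }
        memcpy(&n, buf.data(), sizeof(n));
        if (buf.size() < sizeof(n) + n * sizeof(Item)) {
            return false;  // count does not match the payload
        }
        out.resize(n);
        if (n > 0) {
            memcpy(out.data(), buf.data() + sizeof(n), n * sizeof(Item));
        }
        return true;
    }
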
@@ -1436,10 +1518,26 @@ static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
 }
 
 static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    static std::unordered_map<std::string, std::unordered_map<std::string, bool>> caches;
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    auto &cache = caches[ctx->endpoint];
+    std::string key = op->name;
+    key += std::to_string(op->type);
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        key += std::to_string(op->ne[i]);
+    }
+    key += std::to_string(op->op);
+
+    auto it = cache.find(key);
+    if (it != cache.end()) {
+        return it->second;
+    }
+    bool result = ggml_backend_rpc_support_op(ctx->endpoint.c_str(), op);
+    cache[key] = result;
+    return result;
+
     GGML_UNUSED(dev);
-    GGML_UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
 }
 
 static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
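
One caveat on the cache key above: it concatenates op->name, the result type, the result shape, and the op id with no separators and nothing about the sources, so two ops sharing a name, result type, and shape while differing only in input types (say, Q4_0 vs F16 weights for a MUL_MAT) would hit the same entry. A hypothetical stricter key that also folds in the source tensors, not part of this commit:

    #include <string>
    #include "ggml.h"

    // Build a cache key from the op id, the result type/shape, and every
    // source tensor's type/shape, with separators to keep it unambiguous.
    static std::string support_op_cache_key(const ggml_tensor * op) {
        std::string key = std::to_string(op->op) + '|' + std::to_string(op->type);
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            key += ',' + std::to_string(op->ne[i]);
        }
        for (int s = 0; s < GGML_MAX_SRC && op->src[s] != nullptr; s++) {
            key += '|' + std::to_string(op->src[s]->type);
            for (int i = 0; i < GGML_MAX_DIMS; i++) {
                key += ',' + std::to_string(op->src[s]->ne[i]);
            }
        }
        return key;
    }
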
