Skip to content

Commit 7bf0bd1

Browse files
committed
rpc : fix alloc size logic
1 parent e072b20 commit 7bf0bd1

File tree

3 files changed

+43
-9
lines changed

3 files changed

+43
-9
lines changed

ggml/include/ggml-rpc.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
 #pragma once
 
-#include "ggml.h"
 #include "ggml-backend.h"
 
 #ifdef __cplusplus

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,11 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
 
     nth = std::min(nth, nk0);
 
+    if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+        nrptg = 1;
+    }
+
     ggml_metal_kargs_set_rows args = {
         /*.nk0 =*/ nk0,
         /*.ne01 =*/ ne01,

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ struct rpc_msg_device_count_rsp {
 struct rpc_msg_get_alloc_size_req {
     uint32_t device;
     rpc_tensor tensor;
+    rpc_tensor srcs[GGML_MAX_SRC];
 };
 
 struct rpc_msg_get_alloc_size_rsp {
@@ -572,6 +573,11 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
+    if (!tensor) {
+        memset(&result, 0, sizeof(result));
+        return result;
+    }
+
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
     if (tensor->buffer) {
@@ -753,23 +759,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    // should we query the remote server for the actual size
+    bool rpc_get = false;
+
     // See comments in init_tensor.
-    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
+
+    // ops that require additional memory for fleeting data on certain backends
+    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
+    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
+
+    if (rpc_get) {
         ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
         auto sock = get_socket(buft_ctx->endpoint);
 
-        rpc_msg_get_alloc_size_req request;
-        request.device = buft_ctx->device;
-        request.tensor = serialize_tensor(tensor);
+        rpc_msg_get_alloc_size_req request = {
+            /*.device =*/ buft_ctx->device,
+            /*.tensor =*/ serialize_tensor(tensor),
+            /*.srcs =*/ {},
+        };
+
+        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            request.srcs[i] = serialize_tensor(tensor->src[i]);
+        }
 
+        // TODO: cache the alloc responses to avoid extra RPC calls?
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
         RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
-    } else {
-        return ggml_nbytes(tensor);
     }
+
+    return ggml_nbytes(tensor);
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -1017,20 +1041,26 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     }
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
-        /*.mem_size =*/ ggml_tensor_overhead(),
+        /*.mem_size =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc =*/ true,
     };
 
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
 
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (request.srcs[i].id != 0) {
+            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
+        }
+    }
+
     LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         //No buffer allocated.

0 commit comments

Comments (0)