@@ -128,6 +128,7 @@ struct rpc_msg_device_count_rsp {
 struct rpc_msg_get_alloc_size_req {
     uint32_t device;
     rpc_tensor tensor;
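+    // source tensors of this tensor; the server-side alloc size may depend on them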
+    rpc_tensor srcs[GGML_MAX_SRC];
 };
 
 struct rpc_msg_get_alloc_size_rsp {
@@ -572,6 +573,11 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
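+    // a null tensor (e.g. an unused src slot) is encoded as an all-zero rpc_tensor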
+    if (!tensor) {
+        memset(&result, 0, sizeof(result));
+        return result;
+    }
+
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
     if (tensor->buffer) {
@@ -753,23 +759,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    // should we query the remote server for the actual size?
+    bool rpc_get = false;
+
     // See comments in init_tensor.
-    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
+
+    // ops that require additional memory for fleeting data on certain backends
+    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
+    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
+
+    if (rpc_get) {
         ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
         auto sock = get_socket(buft_ctx->endpoint);
 
-        rpc_msg_get_alloc_size_req request;
-        request.device = buft_ctx->device;
-        request.tensor = serialize_tensor(tensor);
+        rpc_msg_get_alloc_size_req request = {
+            /*.device =*/ buft_ctx->device,
+            /*.tensor =*/ serialize_tensor(tensor),
+            /*.srcs   =*/ {},
+        };
+
+        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            request.srcs[i] = serialize_tensor(tensor->src[i]);
+        }
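+        // at this point unused src slots hold all-zero entries, which the server skips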
 
+        // TODO: cache the alloc responses to avoid extra RPC calls?
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
         RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
-    } else {
-        return ggml_nbytes(tensor);
     }
+
+    return ggml_nbytes(tensor);
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -1017,20 +1041,26 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     }
     ggml_backend_buffer_type_t buft;
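+    // account for the tensor itself plus up to GGML_MAX_SRC deserialized srcs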
     struct ggml_init_params params {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
 
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
 
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
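+    // re-attach the serialized srcs; an all-zero entry (id == 0) marks an empty slot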
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (request.srcs[i].id != 0) {
+            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
+        }
+    }
+
     LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void *)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         // No buffer allocated.