Skip to content

Commit 0c58ba3

Browse files
authored
rpc : reuse compute graph buffers (#21299)
Reuse the buffer for the ggml context which is used for creating the compute graph on the server side. This partially addresses a memory leak created by the CUDA backend due to using buffer addresses as cache keys. ref: #21265 ref: #20315
1 parent 57ace0d commit 0c58ba3

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,8 +1009,8 @@ class rpc_server {
10091009
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
10101010

10111011
struct stored_graph {
1012-
ggml_context_ptr ctx_ptr;
1013-
ggml_cgraph * graph;
1012+
std::vector<uint8_t> buffer;
1013+
ggml_cgraph * graph;
10141014
};
10151015

10161016
private:
@@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
15181518
LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors);
15191519

15201520
size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1521-
1521+
if (stored_graphs[device].buffer.size() < buf_size) {
1522+
stored_graphs[device].buffer.resize(buf_size);
1523+
}
15221524
struct ggml_init_params params = {
15231525
/*.mem_size =*/ buf_size,
1524-
/*.mem_buffer =*/ NULL,
1526+
/*.mem_buffer =*/ stored_graphs[device].buffer.data(),
15251527
/*.no_alloc =*/ true,
15261528
};
15271529
ggml_context_ptr ctx_ptr { ggml_init(params) };
@@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
15511553
}
15521554
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
15531555
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
1554-
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
15551556
stored_graphs[device].graph = graph;
15561557
return true;
15571558
}

0 commit comments

Comments
 (0)