88#include < vector>
99#include < memory>
1010#include < mutex>
11+ #include < random>
1112#include < unordered_map>
1213#include < unordered_set>
1314#ifdef _WIN32
@@ -876,6 +877,7 @@ class rpc_server {
876877 bool graph_compute (const std::vector<uint8_t > & input, rpc_msg_graph_compute_rsp & response);
877878 bool init_tensor (const rpc_msg_init_tensor_req & request);
878879 bool get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
880+ uint64_t random_id ();
879881
880882private:
881883 bool get_cached_file (uint64_t hash, std::vector<uint8_t > & data);
@@ -885,12 +887,19 @@ class rpc_server {
885887 const std::unordered_map<uint64_t , const rpc_tensor*> & tensor_ptrs,
886888 std::unordered_map<uint64_t , struct ggml_tensor *> & tensor_map);
887889
888-
889890 ggml_backend_t backend;
890891 const char * cache_dir;
891- std::unordered_set<ggml_backend_buffer_t > buffers;
892+ std::random_device rd;
893+ // map from remote_ptr key to actual buffer pointer
894+ std::unordered_map<uint64_t , ggml_backend_buffer_t > buffers;
892895};
893896
897+ uint64_t rpc_server::random_id () {
898+ uint64_t high = static_cast <uint64_t >(rd ()) << 32 ;
899+ uint64_t low = static_cast <uint64_t >(rd ());
900+ return (high | low);
901+ }
902+
894903void rpc_server::hello (rpc_msg_hello_rsp & response) {
895904 response.major = RPC_PROTO_MAJOR_VERSION;
896905 response.minor = RPC_PROTO_MINOR_VERSION;
@@ -934,10 +943,12 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
934943 response.remote_ptr = 0 ;
935944 response.remote_size = 0 ;
936945 if (buffer != nullptr ) {
937- response.remote_ptr = reinterpret_cast <uint64_t >(buffer);
946+ uint64_t rpk = random_id ();
947+ response.remote_ptr = rpk;
938948 response.remote_size = buffer->size ;
939- GGML_PRINT_DEBUG (" [%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 " , remote_size: %" PRIu64 " \n " , __func__, request.size , response.remote_ptr , response.remote_size );
940- buffers.insert (buffer);
949+ GGML_PRINT_DEBUG (" [%s] size: %" PRIu64 " -> handle: %" PRIu64 " , remote_size: %" PRIu64 " \n " ,
950+ __func__, request.size , rpk, response.remote_size );
951+ buffers[rpk] = buffer;
941952 } else {
942953 GGML_LOG_ERROR (" [%s] size: %" PRIu64 " -> failed\n " , __func__, request.size );
943954 }
@@ -959,35 +970,38 @@ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
959970
960971bool rpc_server::buffer_get_base (const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
961972 GGML_PRINT_DEBUG (" [%s] remote_ptr: %" PRIx64 " \n " , __func__, request.remote_ptr );
962- ggml_backend_buffer_t buffer = reinterpret_cast < ggml_backend_buffer_t > (request.remote_ptr );
963- if (buffers. find (buffer) == buffers.end ()) {
964- GGML_LOG_ERROR (" [%s] buffer not found\n " , __func__);
973+ auto it = buffers. find (request.remote_ptr );
974+ if (it == buffers.end ()) {
975+ GGML_LOG_ERROR (" [%s] buffer handle not found: % " PRIu64 " \n " , __func__, request. remote_ptr );
965976 return false ;
966977 }
978+ ggml_backend_buffer_t buffer = it->second ;
967979 void * base = ggml_backend_buffer_get_base (buffer);
968980 response.base_ptr = reinterpret_cast <uint64_t >(base);
969981 return true ;
970982}
971983
972984bool rpc_server::free_buffer (const rpc_msg_free_buffer_req & request) {
973985 GGML_PRINT_DEBUG (" [%s] remote_ptr: %" PRIx64 " \n " , __func__, request.remote_ptr );
974- ggml_backend_buffer_t buffer = reinterpret_cast < ggml_backend_buffer_t > (request.remote_ptr );
975- if (buffers. find (buffer) == buffers.end ()) {
976- GGML_LOG_ERROR (" [%s] buffer not found\n " , __func__);
986+ auto it = buffers. find (request.remote_ptr );
987+ if (it == buffers.end ()) {
988+ GGML_LOG_ERROR (" [%s] buffer handle not found: % " PRIu64 " \n " , __func__, request. remote_ptr );
977989 return false ;
978990 }
991+ ggml_backend_buffer_t buffer = it->second ;
979992 ggml_backend_buffer_free (buffer);
980- buffers.erase (buffer );
993+ buffers.erase (it );
981994 return true ;
982995}
983996
984997bool rpc_server::buffer_clear (const rpc_msg_buffer_clear_req & request) {
985998 GGML_PRINT_DEBUG (" [%s] remote_ptr: %" PRIx64 " , value: %u\n " , __func__, request.remote_ptr , request.value );
986- ggml_backend_buffer_t buffer = reinterpret_cast < ggml_backend_buffer_t > (request.remote_ptr );
987- if (buffers. find (buffer) == buffers.end ()) {
988- GGML_LOG_ERROR (" [%s] buffer not found\n " , __func__);
999+ auto it = buffers. find (request.remote_ptr );
1000+ if (it == buffers.end ()) {
1001+ GGML_LOG_ERROR (" [%s] buffer handle not found: % " PRIu64 " \n " , __func__, request. remote_ptr );
9891002 return false ;
9901003 }
1004+ ggml_backend_buffer_t buffer = it->second ;
9911005 ggml_backend_buffer_clear (buffer, request.value );
9921006 return true ;
9931007}
@@ -1011,8 +1025,11 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
10111025 for (uint32_t i = 0 ; i < GGML_MAX_DIMS; i++) {
10121026 result->nb [i] = tensor->nb [i];
10131027 }
1014- result->buffer = reinterpret_cast <ggml_backend_buffer_t >(tensor->buffer );
1015- if (result->buffer && buffers.find (result->buffer ) == buffers.end ()) {
1028+ // convert the remote_ptr handle to an actual buffer pointer
1029+ auto it_buf = buffers.find (tensor->buffer );
1030+ if (it_buf != buffers.end ()) {
1031+ result->buffer = it_buf->second ;
1032+ } else {
10161033 result->buffer = nullptr ;
10171034 }
10181035
@@ -1273,7 +1290,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
12731290 const rpc_tensor * tensor = it_ptr->second ;
12741291
12751292 struct ggml_tensor * result = deserialize_tensor (ctx, tensor);
1276- if (result == nullptr ) {
1293+ if (result == nullptr || result-> buffer == nullptr ) {
12771294 return nullptr ;
12781295 }
12791296 tensor_map[id] = result;
@@ -1366,8 +1383,8 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13661383}
13671384
13681385rpc_server::~rpc_server () {
1369- for (auto buffer : buffers) {
1370- ggml_backend_buffer_free (buffer );
1386+ for (auto &kv : buffers) {
1387+ ggml_backend_buffer_free (kv. second );
13711388 }
13721389}
13731390
0 commit comments