Skip to content

Commit b5cfe56

Browse files
committed
fix aslr leaks of buffer
1 parent e2b7621 commit b5cfe56

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99
#include <memory>
1010
#include <mutex>
11+
#include <random>
1112
#include <unordered_map>
1213
#include <unordered_set>
1314
#ifdef _WIN32
@@ -876,6 +877,7 @@ class rpc_server {
876877
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
877878
bool init_tensor(const rpc_msg_init_tensor_req & request);
878879
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
880+
uint64_t random_id();
879881

880882
private:
881883
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
@@ -885,12 +887,19 @@ class rpc_server {
885887
const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
886888
std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
887889

888-
889890
ggml_backend_t backend;
890891
const char * cache_dir;
891-
std::unordered_set<ggml_backend_buffer_t> buffers;
892+
std::random_device rd;
893+
// map from remote_ptr key to actual buffer pointer
894+
std::unordered_map<uint64_t, ggml_backend_buffer_t> buffers;
892895
};
893896

897+
uint64_t rpc_server::random_id() {
898+
uint64_t high = static_cast<uint64_t>(rd()) << 32;
899+
uint64_t low = static_cast<uint64_t>(rd());
900+
return (high | low);
901+
}
902+
894903
void rpc_server::hello(rpc_msg_hello_rsp & response) {
895904
response.major = RPC_PROTO_MAJOR_VERSION;
896905
response.minor = RPC_PROTO_MINOR_VERSION;
@@ -934,10 +943,12 @@ void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_
934943
response.remote_ptr = 0;
935944
response.remote_size = 0;
936945
if (buffer != nullptr) {
937-
response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
946+
uint64_t rpk = random_id();
947+
response.remote_ptr = rpk;
938948
response.remote_size = buffer->size;
939-
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
940-
buffers.insert(buffer);
949+
GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> handle: %" PRIu64 ", remote_size: %" PRIu64 "\n",
950+
__func__, request.size, rpk, response.remote_size);
951+
buffers[rpk] = buffer;
941952
} else {
942953
GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
943954
}
@@ -959,35 +970,38 @@ void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
959970

960971
bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
961972
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
962-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
963-
if (buffers.find(buffer) == buffers.end()) {
964-
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
973+
auto it = buffers.find(request.remote_ptr);
974+
if (it == buffers.end()) {
975+
GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
965976
return false;
966977
}
978+
ggml_backend_buffer_t buffer = it->second;
967979
void * base = ggml_backend_buffer_get_base(buffer);
968980
response.base_ptr = reinterpret_cast<uint64_t>(base);
969981
return true;
970982
}
971983

972984
bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
973985
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
974-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
975-
if (buffers.find(buffer) == buffers.end()) {
976-
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
986+
auto it = buffers.find(request.remote_ptr);
987+
if (it == buffers.end()) {
988+
GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
977989
return false;
978990
}
991+
ggml_backend_buffer_t buffer = it->second;
979992
ggml_backend_buffer_free(buffer);
980-
buffers.erase(buffer);
993+
buffers.erase(it);
981994
return true;
982995
}
983996

984997
bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
985998
GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
986-
ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
987-
if (buffers.find(buffer) == buffers.end()) {
988-
GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
999+
auto it = buffers.find(request.remote_ptr);
1000+
if (it == buffers.end()) {
1001+
GGML_LOG_ERROR("[%s] buffer handle not found: %" PRIu64 "\n", __func__, request.remote_ptr);
9891002
return false;
9901003
}
1004+
ggml_backend_buffer_t buffer = it->second;
9911005
ggml_backend_buffer_clear(buffer, request.value);
9921006
return true;
9931007
}
@@ -1011,8 +1025,11 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
10111025
for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
10121026
result->nb[i] = tensor->nb[i];
10131027
}
1014-
result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
1015-
if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
1028+
// convert the remote_ptr handle to an actual buffer pointer
1029+
auto it_buf = buffers.find(tensor->buffer);
1030+
if (it_buf != buffers.end()) {
1031+
result->buffer = it_buf->second;
1032+
} else {
10161033
result->buffer = nullptr;
10171034
}
10181035

@@ -1273,7 +1290,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
12731290
const rpc_tensor * tensor = it_ptr->second;
12741291

12751292
struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
1276-
if (result == nullptr) {
1293+
if (result == nullptr || result->buffer == nullptr) {
12771294
return nullptr;
12781295
}
12791296
tensor_map[id] = result;
@@ -1366,8 +1383,8 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13661383
}
13671384

13681385
rpc_server::~rpc_server() {
1369-
for (auto buffer : buffers) {
1370-
ggml_backend_buffer_free(buffer);
1386+
for (auto &kv : buffers) {
1387+
ggml_backend_buffer_free(kv.second);
13711388
}
13721389
}
13731390

0 commit comments

Comments
 (0)