
Commit 8a5f857

Rpc improvement (#480)

Authored by firecoperana, rgerganov, thevilledev, hbuxiaofei, and justinsb
* Add RPC backend in device list to override tensors.

* rpc : prevent crashes on invalid input (#9040)

  Add more checks which prevent RPC server from crashing if invalid input
  is received from client

  # Conflicts:
  #   ggml/src/ggml-rpc.cpp

* rpc : print error message when failed to connect endpoint (#9042)

* Fix RPC error

* Add vulkan, sycl to rpc backend

* add thread in rpc cpu backend

* add cache folder and other improvement in rpc

* add header file

* support for models with non-512 aligned tensors

* rpc : do not wait for response when sending RPC_CMD_SET_TENSOR (#12943)

  RPC_CMD_SET_TENSOR always returns an empty response and we send this
  4 times per token. We can improve TG speed if we don't wait for this
  empty response. The performance impact of this change depends on the
  network latency.

  # Conflicts:
  #   ggml/src/ggml-rpc.cpp

* fix(rpc): Improve input validation and error handling (#13069)

  * fix(rpc): Improve input validation and error handling

    The `rpc-server` was vulnerable to Denial of Service attacks via
    several RPC commands (`SET_TENSOR`, `GRAPH_COMPUTE`, etc.). Malformed
    messages could trigger failed assertions (e.g., invalid `ggml_type`)
    or out-of-bounds reads/writes leading to `GGML_ABORT` calls, crashing
    the server process.

    This PR introduces robust input validation and replaces `abort()`
    calls with graceful error handling:

    - **Type Validation:** `deserialize_tensor` now checks if the
      `tensor->type` is within the valid `GGML_TYPE_COUNT` range *before*
      calling `ggml_new_tensor_4d`. Returns `nullptr` on invalid type.
    - **Bounds Checks:** Replaced `GGML_ABORT` in `set_tensor`,
      `set_tensor_hash`, and `get_tensor` handlers with error logging and
      returning `false` when data/offset parameters are out of buffer
      bounds.
    - **Size Checks:** Added safe arithmetic checks (for overflow) in
      `graph_compute` when calculating required message sizes based on
      client-provided `n_nodes` and `n_tensors`. Returns early if the
      reported sizes conflict with the actual message size or would lead
      to overflow.
    - **Error Propagation:**
      - `create_node` now checks for `nullptr` return values from
        `deserialize_tensor` and its recursive calls, propagating
        `nullptr` upwards on failure. Uses `find` instead of `at` for
        safer map access.
      - `copy_tensor` now checks for `nullptr` from `deserialize_tensor`
        and sets the response status to failure if deserialization or
        bounds checks fail.
      - `graph_compute` now checks for `nullptr` return from
        `create_node` and returns failure status correctly. The final
        return value now reflects the actual computation status.

    These changes improve the RPC server's resilience against malformed
    client requests, preventing crashes and ensuring errors are handled
    more gracefully.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): address pr comments

    removed comments and unnecessary returns

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): ambiguous nullptr from create_node

    rpc_server::create_node could previously return nullptr if the input
    ID was 0 (valid) or if an internal error (deserialization, recursion
    failure) occurred (invalid). This ambiguity made error handling
    difficult for the caller (`graph_compute`).

    This commit clarifies the meaning of nullptr:

    - `graph_compute` now checks if the input 'id' was non-zero when
      `create_node` returns nullptr, correctly identifying failures
      versus intentional null links.
    - `create_node` avoids recursive calls for zero IDs and propagates
      nullptr unambiguously on failure during recursion.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): initial zero check in create_node

    The caller (`graph_compute`) already checks `id != 0` when handling a
    `nullptr` return from `create_node`, correctly distinguishing
    intentional null links from actual errors. This makes the initial
    `if (id == 0)` check redundant.

    Also removes the log message when a tensor ID is not found in the
    provided map which was added in this branch.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * fix(rpc): Handle get_alloc_size failure in server

    Check the return value of `server.get_alloc_size` in the RPC server
    loop. If the call fails, return early to close the connection.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): input size validation in graph_compute

    Removes detailed, step-by-step size calculations and overflow checks
    in favor of simpler direct comparisons, assuming 64-bit overflow is
    unlikely.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): remove extra status code setting

    Removes the explicit setting of `response.result = GGML_STATUS_FAILED`
    when `create_node` returns `nullptr` within `graph_compute`. Primary
    signal is the `false` return value in case of failure.

    Signed-off-by: Ville Vesilehto <[email protected]>

  * refactor(rpc): remove redundant check for tensor->type

    Breaks CI on ubuntu-cpu-make. Tensor type is uint32_t, thus the check
    is not needed.

    Signed-off-by: Ville Vesilehto <[email protected]>

  ---------

  Signed-off-by: Ville Vesilehto <[email protected]>

  # Conflicts:
  #   ggml/src/ggml-rpc.cpp

* rpc : fix cache directory initialization (#13188)

  Signed-off-by: xiaofei <[email protected]>

  # Conflicts:
  #   examples/rpc/rpc-server.cpp

* rpc : avoid uninitialized memory in serialize_tensor (#13210)

  Zero out the name and padding buffers.

* fix merge error

* Add hello command in RPC

* bug fix

* add rpc header

* fix bug for missing rpc names

* add tcp no delay for rpc

* add back webui

---------

Signed-off-by: Ville Vesilehto <[email protected]>
Signed-off-by: xiaofei <[email protected]>
Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Radoslav Gerganov <[email protected]>
Co-authored-by: matt23456 <matt23456>
Co-authored-by: Ville Vesilehto <[email protected]>
Co-authored-by: xiaofei <[email protected]>
Co-authored-by: Justin Santa Barbara <[email protected]>
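The `ggml/src/ggml-rpc.cpp` changes that implement the validation described in #13069 are not part of this excerpt. As a rough illustration of the pattern only (the struct and function names below are hypothetical; the real checks live in the `deserialize_tensor`, `set_tensor`, and `get_tensor` handlers), the type and bounds checks amount to:

```cpp
// Illustrative sketch of the validation pattern described above; not the
// actual ggml-rpc.cpp code.
#include "ggml.h"
#include <cstdint>
#include <cstdio>

// Hypothetical view of the client-provided fields relevant to this sketch;
// the real wire struct carries more members (ne/nb arrays, name, ...).
struct rpc_tensor_view {
    uint32_t type;    // client-provided ggml_type, must be range-checked
    uint64_t offset;  // client-provided offset into the target buffer
    uint64_t size;    // client-provided payload size
};

// Reject out-of-range types instead of letting ggml assert later.
static bool validate_type(const rpc_tensor_view & t) {
    if (t.type >= GGML_TYPE_COUNT) {
        fprintf(stderr, "invalid ggml_type %u\n", t.type);
        return false;
    }
    return true;
}

// Reject reads/writes that would fall outside the destination buffer instead
// of aborting: offset and size must both stay within buf_size.
static bool validate_bounds(const rpc_tensor_view & t, uint64_t buf_size) {
    if (t.offset > buf_size || t.size > buf_size - t.offset) {
        fprintf(stderr, "out-of-bounds access: offset=%llu size=%llu buf=%llu\n",
                (unsigned long long) t.offset,
                (unsigned long long) t.size,
                (unsigned long long) buf_size);
        return false;
    }
    return true;
}
```

Note that the bounds check never computes `offset + size` directly, so a malicious client cannot overflow the addition to slip past the limit; this mirrors the "safe arithmetic checks" called out in the commit message.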
1 parent 63ef0a3 commit 8a5f857

11 files changed: +1234 -484 lines

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@ include(CheckIncludeFileCXX)
 set(CMAKE_WARN_UNUSED_CLI YES)
 
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
-
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED true)

common/common.cpp

Lines changed: 33 additions & 18 deletions
@@ -81,7 +81,9 @@
 #endif
 #define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
 #endif // LLAMA_USE_CURL
-
+#ifdef GGML_USE_RPC
+# include "ggml-rpc.h"
+#endif
 using json = nlohmann::ordered_json;
 
 //
@@ -1004,6 +1006,35 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     if (arg == "--rpc") {
         CHECK_ARG
         params.rpc_servers = argv[i];
+        std::string servers(params.rpc_servers);
+        size_t pos = 0;
+        while ((pos = servers.find(",")) != std::string::npos) {
+            std::string server = servers.substr(0, pos);
+            ggml_backend_rpc_buffer_type(server.c_str());
+            servers.erase(0, pos + 1);
+        }
+        ggml_backend_rpc_buffer_type(servers.c_str());
+        return true;
+    }
+    if (arg == "--override-kv") {
+        CHECK_ARG
+        if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
+            fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+            invalid_param = true;
+            return true;
+        }
+        return true;
+    }
+    if (arg == "--override-tensor" || arg == "-ot") {
+        CHECK_ARG
+        /*for (auto endpoint : params.rpc_servers.split)
+        {
+
+        }*/
+        if (!parse_buft_overrides(std::string{ argv[i] }, params.tensor_buft_overrides)) {
+            fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
+            invalid_param = true;
+        }
         return true;
     }
     if (arg == "--no-mmap") {
@@ -1211,23 +1242,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
         return true;
     }
-    if (arg == "--override-kv") {
-        CHECK_ARG
-        if (!string_parse_kv_override(argv[i], params.kv_overrides)) {
-            fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
-            invalid_param = true;
-            return true;
-        }
-        return true;
-    }
-    if (arg == "--override-tensor" || arg == "-ot") {
-        CHECK_ARG
-        if (!parse_buft_overrides(std::string{argv[i]}, params.tensor_buft_overrides)) {
-            fprintf(stderr, "error: Invalid tensor buffer type override: %s\n", argv[i]);
-            invalid_param = true;
-        }
-        return true;
-    }
+
     if (arg == "--offload-policy" || arg == "-op") {
         CHECK_ARG
         auto p = string_split_pairs<int,int>(argv[i], ',');

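For reference, the comma-splitting that the new `--rpc` branch above performs inline can be written as a small standalone helper. This is an illustrative sketch only; `split_rpc_endpoints` is not part of the patch:

```cpp
#include <string>
#include <vector>

// Illustrative helper (not in the patch): split a comma-separated endpoint
// list such as "host1:50052,host2:50052" into individual entries, mirroring
// the inline loop in the --rpc branch above.
static std::vector<std::string> split_rpc_endpoints(const std::string & servers) {
    std::vector<std::string> endpoints;
    size_t start = 0;
    size_t pos;
    while ((pos = servers.find(',', start)) != std::string::npos) {
        endpoints.push_back(servers.substr(start, pos - start));
        start = pos + 1;
    }
    endpoints.push_back(servers.substr(start)); // last (or only) endpoint
    return endpoints;
}
```

Each endpoint string is then passed to `ggml_backend_rpc_buffer_type()`, as in the diff above, so the RPC buffer type for that server exists early; this is what the first commit-message item ("Add RPC backend in device list to override tensors") refers to.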
examples/rpc/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -1,2 +1,4 @@
-add_executable(rpc-server rpc-server.cpp)
-target_link_libraries(rpc-server PRIVATE ggml llama)
+set(TARGET rpc-server)
+add_executable(${TARGET} rpc-server.cpp)
+target_link_libraries(${TARGET} PRIVATE ggml)
+target_compile_features(${TARGET} PRIVATE cxx_std_17)

examples/rpc/rpc-server.cpp

Lines changed: 199 additions & 15 deletions
@@ -5,33 +5,166 @@
 #ifdef GGML_USE_METAL
 #include "ggml-metal.h"
 #endif
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
 
 #include "ggml-rpc.h"
 #ifdef _WIN32
+# define DIRECTORY_SEPARATOR '\\'
+# define NOMINMAX
+# include <locale>
 # include <windows.h>
+# include <fcntl.h>
+# include <io.h>
 #else
+# define DIRECTORY_SEPARATOR '/'
 # include <unistd.h>
+# include <sys/stat.h>
 #endif
 #include <string>
 #include <stdio.h>
+#include <algorithm>
+#include <thread>
+#include <fstream>
+#include <filesystem>
+#include <codecvt>
+
+namespace fs = std::filesystem;
+
+// NOTE: this is copied from common.cpp to avoid linking with libcommon
+// returns true if successful, false otherwise
+static bool fs_create_directory_with_parents(const std::string& path) {
+#ifdef _WIN32
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring wpath = converter.from_bytes(path);
+
+    // if the path already exists, check whether it's a directory
+    const DWORD attributes = GetFileAttributesW(wpath.c_str());
+    if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+        return true;
+    }
+
+    size_t pos_slash = 0;
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+        const std::wstring subpath = wpath.substr(0, pos_slash);
+        const wchar_t* test = subpath.c_str();
+
+        const bool success = CreateDirectoryW(test, NULL);
+        if (!success) {
+            const DWORD error = GetLastError();
+
+            // if the path already exists, ensure that it's a directory
+            if (error == ERROR_ALREADY_EXISTS) {
+                const DWORD attributes = GetFileAttributesW(subpath.c_str());
+                if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+                    return false;
+                }
+            }
+            else {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#else
+    // if the path already exists, check whether it's a directory
+    struct stat info;
+    if (stat(path.c_str(), &info) == 0) {
+        return S_ISDIR(info.st_mode);
+    }
+
+    size_t pos_slash = 1; // skip leading slashes for directory creation
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+        const std::string subpath = path.substr(0, pos_slash);
+        struct stat info;
+
+        // if the path already exists, ensure that it's a directory
+        if (stat(subpath.c_str(), &info) == 0) {
+            if (!S_ISDIR(info.st_mode)) {
+                return false;
+            }
+        }
+        else {
+            // create parent directories
+            const int ret = mkdir(subpath.c_str(), 0755);
+            if (ret != 0) {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#endif // _WIN32
+}
+
+// NOTE: this is copied from common.cpp to avoid linking with libcommon
+static std::string fs_get_cache_directory() {
+    std::string cache_directory = "";
+    auto ensure_trailing_slash = [](std::string p) {
+        // Make sure to add trailing slash
+        if (p.back() != DIRECTORY_SEPARATOR) {
+            p += DIRECTORY_SEPARATOR;
+        }
+        return p;
+    };
+    if (getenv("LLAMA_CACHE")) {
+        cache_directory = std::getenv("LLAMA_CACHE");
+    }
+    else {
+#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
+        if (std::getenv("XDG_CACHE_HOME")) {
+            cache_directory = std::getenv("XDG_CACHE_HOME");
+        }
+        else {
+            cache_directory = std::getenv("HOME") + std::string("/.cache/");
+        }
+#elif defined(__APPLE__)
+        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+        cache_directory = std::getenv("LOCALAPPDATA");
+#else
+# error Unknown architecture
+#endif
+        cache_directory = ensure_trailing_slash(cache_directory);
+        cache_directory += "llama.cpp";
+    }
+    return ensure_trailing_slash(cache_directory);
+}
 
 struct rpc_server_params {
     std::string host = "127.0.0.1";
     int port = 50052;
     size_t backend_mem = 0;
+    bool use_cache = false;
+    int n_threads = std::max(1U, std::thread::hardware_concurrency() / 2);
 };
 
-static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
+static void print_usage(int /*argc*/, char** argv, rpc_server_params params) {
     fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
     fprintf(stderr, "options:\n");
-    fprintf(stderr, "  -h, --help            show this help message and exit\n");
-    fprintf(stderr, "  -H HOST, --host HOST  host to bind to (default: %s)\n", params.host.c_str());
-    fprintf(stderr, "  -p PORT, --port PORT  port to bind to (default: %d)\n", params.port);
-    fprintf(stderr, "  -m MEM, --mem MEM     backend memory size (in MB)\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -t, --threads         number of threads for the CPU backend (default: %d)\n", params.n_threads);
+    fprintf(stderr, "  -H HOST, --host HOST  host to bind to (default: %s)\n", params.host.c_str());
+    fprintf(stderr, "  -p PORT, --port PORT  port to bind to (default: %d)\n", params.port);
+    fprintf(stderr, "  -m MEM, --mem MEM     backend memory size (in MB)\n");
+    fprintf(stderr, "  -c, --cache           enable local file cache\n");
    fprintf(stderr, "\n");
 }
 
-static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
+static bool rpc_server_params_parse(int argc, char** argv, rpc_server_params& params) {
     std::string arg;
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -40,23 +173,40 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
                 return false;
             }
             params.host = argv[i];
-        } else if (arg == "-p" || arg == "--port") {
+        }
+        else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                return false;
+            }
+            params.n_threads = std::stoi(argv[i]);
+            if (params.n_threads <= 0) {
+                fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
+                return false;
+            }
+        }
+        else if (arg == "-p" || arg == "--port") {
             if (++i >= argc) {
                 return false;
             }
             params.port = std::stoi(argv[i]);
            if (params.port <= 0 || params.port > 65535) {
                 return false;
             }
-        } else if (arg == "-m" || arg == "--mem") {
+        }
+        else if (arg == "-c" || arg == "--cache") {
+            params.use_cache = true;
+        }
+        else if (arg == "-m" || arg == "--mem") {
             if (++i >= argc) {
                 return false;
             }
             params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
-        } else if (arg == "-h" || arg == "--help") {
+        }
+        else if (arg == "-h" || arg == "--help") {
            print_usage(argc, argv, params);
             exit(0);
-        } else {
+        }
+        else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv, params);
             exit(0);
@@ -65,7 +215,7 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
     return true;
 }
 
-static ggml_backend_t create_backend() {
+static ggml_backend_t create_backend(const rpc_server_params& params) {
     ggml_backend_t backend = NULL;
 #ifdef GGML_USE_CUDA
     fprintf(stderr, "%s: using CUDA backend\n", __func__);
@@ -79,19 +229,36 @@ static ggml_backend_t create_backend() {
     if (!backend) {
         fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
     }
+#elif GGML_USE_VULKAN
+    fprintf(stderr, "%s: using Vulkan backend\n", __func__);
+    backend = ggml_backend_vk_init(0); // init device 0
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_vulkan_init() failed\n", __func__);
+    }
+#elif GGML_USE_SYCL
+    fprintf(stderr, "%s: using SYCL backend\n", __func__);
+    backend = ggml_backend_sycl_init(0); // init device 0
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__);
+    }
 #endif
 
     // if there aren't GPU Backends fallback to CPU backend
     if (!backend) {
         fprintf(stderr, "%s: using CPU backend\n", __func__);
         backend = ggml_backend_cpu_init();
+        ggml_backend_cpu_set_n_threads(backend, params.n_threads);
     }
     return backend;
 }
 
 static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
 #ifdef GGML_USE_CUDA
     ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
+#elif GGML_USE_VULKAN
+    ggml_backend_vk_get_device_memory(0, free_mem, total_mem);
+#elif GGML_USE_SYCL
+    ggml_backend_sycl_get_device_memory(0, free_mem, total_mem);
 #else
 #ifdef _WIN32
     MEMORYSTATUSEX status;
@@ -125,7 +292,7 @@ int main(int argc, char * argv[]) {
         fprintf(stderr, "\n");
     }
 
-    ggml_backend_t backend = create_backend();
+    ggml_backend_t backend = create_backend(params);
     if (!backend) {
         fprintf(stderr, "Failed to create backend\n");
         return 1;
@@ -135,11 +302,28 @@ int main(int argc, char * argv[]) {
     if (params.backend_mem > 0) {
         free_mem = params.backend_mem;
         total_mem = params.backend_mem;
-    } else {
+    }
+    else {
         get_backend_memory(&free_mem, &total_mem);
     }
-    printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
-    start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
+    const char * cache_dir = nullptr;
+    std::string cache_dir_str;
+    if (params.use_cache) {
+        cache_dir_str = fs_get_cache_directory() + "rpc/";
+        if (!fs_create_directory_with_parents(cache_dir_str)) {
+            fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
+            return 1;
+        }
+        cache_dir = cache_dir_str.c_str();
+    }
+    printf("Starting RPC server v%d.%d.%d\n",
+        RPC_PROTO_MAJOR_VERSION,
+        RPC_PROTO_MINOR_VERSION,
+        RPC_PROTO_PATCH_VERSION);
+    printf("  endpoint       : %s\n", endpoint.c_str());
+    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
+    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
+    ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
     ggml_backend_free(backend);
     return 0;
 }

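The `ggml/src/ggml-rpc.cpp` changes behind the "add tcp no delay for rpc" item in the commit message are not shown in this excerpt. As an illustration of the technique only (POSIX sockets; the helper name is hypothetical), disabling Nagle's algorithm on a connected socket looks like this:

```cpp
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>

// Illustrative sketch (not the actual ggml-rpc.cpp code): send small packets
// immediately instead of letting the kernel coalesce them.
static bool set_tcp_no_delay(int sockfd) {
    int flag = 1;
    return setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &flag, sizeof(flag)) == 0;
}
```

Small RPC messages such as the per-token `RPC_CMD_SET_TENSOR` commands are latency-sensitive, so sending them without buffering serves the same goal as the no-wait change from #12943.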