Skip to content

Commit 997e304

Browse files
committed
[X] working with rpc, but slow
1 parent 85ea1b8 commit 997e304

File tree

3 files changed

+71
-6
lines changed

3 files changed

+71
-6
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
namespace fs = std::filesystem;
3535

36+
// Forward declaration for device map access
37+
static std::unordered_map<std::string, ggml_backend_dev_t>& get_rpc_dev_map();
38+
3639
static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB
3740

3841
#ifdef _WIN32
@@ -1760,16 +1763,33 @@ static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
17601763
}
17611764

17621765
static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
1763-
return 0;
1766+
const auto& dev_map = get_rpc_dev_map();
1767+
return dev_map.size();
17641768

17651769
GGML_UNUSED(reg);
17661770
}
17671771

17681772
static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
1769-
GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
1773+
const auto& dev_map = get_rpc_dev_map();
1774+
1775+
if (index >= dev_map.size()) {
1776+
return nullptr;
1777+
}
1778+
1779+
// Convert unordered_map to vector to access by index
1780+
std::vector<ggml_backend_dev_t> devices;
1781+
devices.reserve(dev_map.size());
1782+
for (const auto& pair : dev_map) {
1783+
devices.push_back(pair.second);
1784+
}
1785+
1786+
if (index < devices.size()) {
1787+
return devices[index];
1788+
}
1789+
1790+
return nullptr;
17701791

17711792
GGML_UNUSED(reg);
1772-
GGML_UNUSED(index);
17731793
}
17741794

17751795
static ggml_backend_buffer_type_t ggml_backend_rpc_split_buffer_type(int main_device, const float * tensor_split) {
@@ -1818,8 +1838,14 @@ ggml_backend_reg_t ggml_backend_rpc_reg(void) {
18181838
return &ggml_backend_rpc_reg;
18191839
}
18201840

1821-
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1841+
// Expose the device map for enumeration
1842+
static std::unordered_map<std::string, ggml_backend_dev_t>& get_rpc_dev_map() {
18221843
static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
1844+
return dev_map;
1845+
}
1846+
1847+
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
1848+
auto& dev_map = get_rpc_dev_map();
18231849

18241850
static std::mutex mutex;
18251851
std::lock_guard<std::mutex> lock(mutex);

src/llama-model.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,13 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s
377377
if (ggml_backend_split_buffer_type_fn) {
378378
size_t dev_index = [&]() {
379379
auto * reg = ggml_backend_dev_backend_reg(dev);
380-
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); ++i) {
381-
if (ggml_backend_reg_dev_get(reg, i) == dev) {
380+
size_t reg_dev_count = ggml_backend_reg_dev_count(reg);
381+
LLAMA_LOG_DEBUG("%s: device %s, reg %s, device count %zu\n", __func__, ggml_backend_dev_name(dev), ggml_backend_reg_name(reg), reg_dev_count);
382+
for (size_t i = 0; i < reg_dev_count; ++i) {
383+
ggml_backend_dev_t reg_dev = ggml_backend_reg_dev_get(reg, i);
384+
LLAMA_LOG_DEBUG("%s: comparing device %s with reg device %s at index %zu\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_name(reg_dev), i);
385+
if (reg_dev == dev) {
386+
LLAMA_LOG_DEBUG("%s: found device %s at index %zu\n", __func__, ggml_backend_dev_name(dev), i);
382387
return i;
383388
}
384389
}

tools/llama-bench/llama-bench.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "common.h"
2222
#include "ggml.h"
23+
#include "ggml-rpc.h"
2324
#include "llama.h"
2425

2526
#ifdef _WIN32
@@ -1827,6 +1828,39 @@ int main(int argc, char ** argv) {
18271828

18281829
cmd_params params = parse_cmd_params(argc, argv);
18291830

1831+
// Register RPC devices if specified
1832+
for (const auto& rpc_servers_str : params.rpc_servers) {
1833+
if (!rpc_servers_str.empty()) {
1834+
auto rpc_servers = string_split<std::string>(rpc_servers_str, ',');
1835+
if (!rpc_servers.empty()) {
1836+
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
1837+
if (!rpc_reg) {
1838+
fprintf(stderr, "%s: failed to find RPC backend\n", __func__);
1839+
return 1;
1840+
}
1841+
1842+
typedef ggml_backend_dev_t (*ggml_backend_rpc_add_device_t)(const char * endpoint);
1843+
ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn =
1844+
(ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
1845+
if (!ggml_backend_rpc_add_device_fn) {
1846+
fprintf(stderr, "%s: failed to find RPC device add function\n", __func__);
1847+
return 1;
1848+
}
1849+
1850+
// Register each RPC device
1851+
for (const std::string & server : rpc_servers) {
1852+
ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
1853+
if (dev) {
1854+
ggml_backend_device_register(dev);
1855+
} else {
1856+
fprintf(stderr, "%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
1857+
return 1;
1858+
}
1859+
}
1860+
}
1861+
}
1862+
}
1863+
18301864
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
18311865
if (!cpu_dev) {
18321866
fprintf(stderr, "%s: error: CPU backend is not loaded\n", __func__);

0 commit comments

Comments
 (0)