
Commit 496666c

Merge branch 'parallel' into wip

2 parents cd2cf74 + 073a8eb

2 files changed: 39 additions, 21 deletions


ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 26 additions & 21 deletions
@@ -936,15 +936,21 @@ static ggml_tensor* ggml_backend_tp_node_compute_split(int device_index, ggml_te
 }

 static bool immediate_compute = true;
-static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute) {
+static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute, std::function<void(int, std::set<ggml_tensor*>)> flush_compute) {
     std::set<ggml_tensor*> pending_gathers;
     for (int node_index = 0; node_index < cgraph->n_nodes; node_index++) {
         auto tensor = cgraph->nodes[node_index];
         auto extra = (ggml_tensor_parallel_extra *)tensor->extra;

         // wait for async memcpy to finish if needed
-        if ((extra->needs_src_rejoin || immediate_compute) && pending_gathers.size()) {
-            if (gather_pending && !gather_pending(node_index, pending_gathers)) {
+        if (extra->needs_src_rejoin && pending_gathers.size()) {
+            if (!immediate_compute) {
+                if (flush_compute) {
+                    flush_compute(node_index, pending_gathers);
+                }
+            }
+
+            if (!gather_pending(node_index, pending_gathers)) {
                 return;
             }
             pending_gathers.clear();
@@ -957,6 +963,12 @@ static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::func
         if (extra->has_rejoin) {
             pending_gathers.insert(tensor);
         }
+
+        if (immediate_compute) {
+            if (flush_compute) {
+                flush_compute(node_index, pending_gathers);
+            }
+        }
     }
 }

@@ -975,9 +987,13 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
     auto device_index = thread->device_index;
     auto be = ggml_parallel_backends[device_index];

+    if (!be->iface.cpy_tensor2d_async) {
+        GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
+    }
+
     int rejoins = 0;

-    auto flush_compute = [&](int node_index) {
+    auto flush_compute = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         if (backend_graph->n_nodes ) {
             auto status = be->iface.graph_compute(be, backend_graph);
             if (status != GGML_STATUS_SUCCESS) {
@@ -986,19 +1002,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
             backend_graph->n_nodes = 0;
         }
         thread->end = node_index;
-    };
-
-    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
-        flush_compute(node_index);

         for (auto & tensor : pending_gathers) {
             auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
             auto wrapped = extra->tensors[device_index];

-            if (!be->iface.cpy_tensor2d_async) {
-                GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
-            }
-
             // async copies
             for (size_t other_device_index = 0; other_device_index < ggml_parallel_devices.size(); other_device_index++) {
                 auto other_be = ggml_parallel_backends[other_device_index];
@@ -1018,7 +1026,9 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
                 }
             }
         }
+    };

+    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         rejoins++;
         // synchronize self and then release peers
         ggml_backend_synchronize(be);
@@ -1043,16 +1053,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
         backend_graph->nodes[backend_graph->n_nodes++] = ggml_backend_tp_node_compute_split(device_index, tensor);
         extra->computed[device_index] = true;

-        if (immediate_compute) {
-            flush_compute(node_index);
-            ggml_backend_synchronize(be);
-        }
-
         return true;
     };

-    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute);
-    flush_compute(cgraph->n_nodes);
+    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute, flush_compute);
+    flush_compute(cgraph->n_nodes, std::set<ggml_tensor*>());

     thread->done.unlock();

@@ -1817,7 +1822,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
     ggml_backend_tp_buffer_compute_graph(cgraph, nullptr, [&](int node_index, ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
         do_init(tensor, extra);
         return true;
-    });
+    }, nullptr);

     // calculate the sizes needed for gathering the tensors.
     // this must happen on main thread to prevent race conditions on gather tensor setup.
@@ -1849,7 +1854,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
             gather_buft_sizes_cur[device_index] += extra->gather_buft_sizes[device_index];
         }
         return true;
-    });
+    }, nullptr);

     // allocate the gather buffers
     for (size_t device_index = 0; device_index < ggml_parallel_devices.size(); device_index++) {
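
For reference, a minimal standalone sketch of the control flow that the new flush_compute parameter establishes, using stand-in types rather than ggml internals (node, compute_graph, and the callback bodies below are illustrative assumptions, not the actual backend code): batched compute is flushed before a gather when immediate_compute is off, or after every node when it is on.

// Standalone sketch of the flush_compute ordering; stand-in types, not ggml code.
#include <cstdio>
#include <functional>
#include <set>
#include <vector>

struct node { int id; bool needs_src_rejoin; bool has_rejoin; };

static bool immediate_compute = true;

static void compute_graph(const std::vector<node> & nodes,
                          std::function<bool(int, std::set<const node*>)> gather_pending,
                          std::function<bool(int, const node &)> compute,
                          std::function<void(int, std::set<const node*>)> flush_compute) {
    std::set<const node*> pending_gathers;
    for (int node_index = 0; node_index < (int) nodes.size(); node_index++) {
        const node & n = nodes[node_index];

        // before a node that consumes rejoined sources: flush batched work first
        // (only when not computing immediately), then run the gather.
        if (n.needs_src_rejoin && pending_gathers.size()) {
            if (!immediate_compute && flush_compute) {
                flush_compute(node_index, pending_gathers);
            }
            if (!gather_pending(node_index, pending_gathers)) {
                return;
            }
            pending_gathers.clear();
        }

        if (!compute(node_index, n)) {
            return;
        }
        if (n.has_rejoin) {
            pending_gathers.insert(&n);
        }

        // in immediate mode, flush after every node instead of batching.
        if (immediate_compute && flush_compute) {
            flush_compute(node_index, pending_gathers);
        }
    }
}

int main() {
    std::vector<node> graph = { {0, false, true}, {1, true, false} };
    compute_graph(graph,
        [](int i, std::set<const node*> p) { printf("gather before node %d (%zu pending)\n", i, p.size()); return true; },
        [](int i, const node &)            { printf("compute node %d\n", i); return true; },
        [](int i, std::set<const node*>)   { printf("flush up to node %d\n", i); });
    return 0;
}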

test.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+# pip install huggingface_hub hf_transfer
+import os # Optional for faster downloading
+os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+name = "ubergarm/Qwen3-235B-A22B-GGUF"
+
+from huggingface_hub import snapshot_download
+snapshot_download(
+    repo_id = name,
+    local_dir = f"/mnt/scrypted-nvr/{name}",
+    allow_patterns = ["*IQ3_K*"], # Select quant type UD-IQ1_S for 1.58bit
+)
+
