
Commit 073a8eb

fix

1 parent cfbaccd commit 073a8eb

File tree

1 file changed: +24 -19 lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 24 additions & 19 deletions
@@ -892,14 +892,20 @@ static ggml_tensor* ggml_backend_tp_node_compute_split(int device_index, ggml_te
 }
 
 static bool immediate_compute = false;
-static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute) {
+static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute, std::function<void(int, std::set<ggml_tensor*>)> flush_compute) {
     std::set<ggml_tensor*> pending_gathers;
     for (int node_index = 0; node_index < cgraph->n_nodes; node_index++) {
         auto tensor = cgraph->nodes[node_index];
         auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
 
         // wait for async memcpy to finish if needed
-        if ((extra->needs_src_rejoin || immediate_compute) && pending_gathers.size()) {
+        if (extra->needs_src_rejoin && pending_gathers.size()) {
+            if (!immediate_compute) {
+                if (flush_compute) {
+                    flush_compute(node_index, pending_gathers);
+                }
+            }
+
             if (!gather_pending(node_index, pending_gathers)) {
                 return;
             }
@@ -913,6 +919,12 @@ static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::func
         if (extra->has_rejoin) {
             pending_gathers.insert(tensor);
         }
+
+        if (immediate_compute) {
+            if (flush_compute) {
+                flush_compute(node_index, pending_gathers);
+            }
+        }
     }
 }
 
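The two hunks above change the graph walker so that flushing is driven by the walker itself rather than folded into the gather path: in batched mode the queued subgraph is flushed once, just before a gather barrier; with immediate_compute it is flushed after every node. A minimal standalone sketch of that protocol, using stand-in node types rather than the real ggml structures:

    #include <cstdio>
    #include <functional>
    #include <set>
    #include <vector>

    struct FakeNode { bool needs_src_rejoin; bool has_rejoin; };

    static void walk(const std::vector<FakeNode> & nodes, bool immediate,
                     std::function<void(int, std::set<const FakeNode *>)> flush) {
        std::set<const FakeNode *> pending;
        for (int i = 0; i < (int) nodes.size(); i++) {
            if (nodes[i].needs_src_rejoin && !pending.empty()) {
                if (!immediate && flush) flush(i, pending); // batched: flush before the barrier
                pending.clear();                            // stands in for gather_pending()
            }
            // ... compute node i here ...
            if (nodes[i].has_rejoin) pending.insert(&nodes[i]);
            if (immediate && flush) flush(i, pending);      // immediate: flush every node
        }
    }

    int main() {
        std::vector<FakeNode> graph = {{false, true}, {false, false}, {true, false}};
        walk(graph, false, [](int i, std::set<const FakeNode *> pending) {
            std::printf("flush at node %d, %zu pending gather(s)\n", i, pending.size());
        });
    }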
@@ -931,9 +943,13 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
     auto device_index = thread->device_index;
     auto be = ggml_parallel_backends[device_index];
 
+    if (!be->iface.cpy_tensor2d_async) {
+        GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
+    }
+
     int rejoins = 0;
 
-    auto flush_compute = [&](int node_index) {
+    auto flush_compute = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         if (backend_graph->n_nodes ) {
             auto status = be->iface.graph_compute(be, backend_graph);
             if (status != GGML_STATUS_SUCCESS) {
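The cpy_tensor2d_async probe now runs once at thread entry (it previously sat inside the per-tensor gather loop, removed in the next hunk), so an unsupported backend aborts before any graph work is queued. A rough sketch of that fail-fast shape, with hypothetical stand-ins for the backend types:

    #include <cstdio>
    #include <cstdlib>

    // Hypothetical stand-in: a backend whose async-copy entry point is optional.
    struct FakeBackend {
        const char * name;
        void (*cpy_tensor2d_async)(FakeBackend *); // null when unsupported
    };

    static void compute_one(FakeBackend * be) {
        if (!be->cpy_tensor2d_async) { // hoisted: checked once, not per tensor
            std::fprintf(stderr, "Backend %s does not support async tensor copy.\n", be->name);
            std::abort();
        }
        // ... the gather loop below can now assume the capability exists ...
    }

    int main() {
        FakeBackend be = {"fake", nullptr};
        compute_one(&be); // aborts: capability missing
    }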
@@ -942,19 +958,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
             backend_graph->n_nodes = 0;
         }
         thread->end = node_index;
-    };
-
-    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
-        flush_compute(node_index);
 
         for (auto & tensor : pending_gathers) {
             auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
             auto wrapped = extra->tensors[device_index];
 
-            if (!be->iface.cpy_tensor2d_async) {
-                GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
-            }
-
             // async copies
             for (size_t other_device_index = 0; other_device_index < ggml_parallel_devices.size(); other_device_index++) {
                 auto other_be = ggml_parallel_backends[other_device_index];
@@ -974,7 +982,9 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
                 }
             }
         }
+    };
 
+    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         rejoins++;
         // synchronize self and then release peers
         ggml_backend_synchronize(be);
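Net effect of the two hunks above: the async peer copies move from gather_pending into the tail of flush_compute, and gather_pending shrinks to the synchronize-and-release step; the walker now invokes the two phases in order instead of gather_pending calling flush_compute itself. A compressed sketch of the resulting call order (stand-in types again, not the real tensor set):

    #include <cstdio>
    #include <set>

    int main() {
        std::set<int> pending = {42};

        // phase 1: submit the batched subgraph, then enqueue async copies
        // for every tensor awaiting a gather
        auto flush_compute = [](int node_index, const std::set<int> & pending) {
            std::printf("flush at node %d: submit subgraph, enqueue %zu async copies\n",
                        node_index, pending.size());
        };

        // phase 2: wait for those copies, then release the peer devices
        auto gather_pending = [](int node_index, const std::set<int> & pending) {
            std::printf("gather at node %d: synchronize, release %zu peer tensor(s)\n",
                        node_index, pending.size());
            return true;
        };

        flush_compute(3, pending);  // the walker flushes first ...
        gather_pending(3, pending); // ... then gathers
    }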
@@ -999,16 +1009,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
         backend_graph->nodes[backend_graph->n_nodes++] = ggml_backend_tp_node_compute_split(device_index, tensor);
         extra->computed[device_index] = true;
 
-        if (immediate_compute) {
-            flush_compute(node_index);
-            ggml_backend_synchronize(be);
-        }
-
         return true;
     };
 
-    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute);
-    flush_compute(cgraph->n_nodes);
+    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute, flush_compute);
+    flush_compute(cgraph->n_nodes, std::set<ggml_tensor*>());
 
     thread->done.unlock();
@@ -1061,7 +1066,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
             gather_buft_sizes_cur[device_index] += extra->gather_buft_sizes[device_index];
         }
         return true;
-    });
+    }, nullptr);
 
     // allocate the gather buffers
     for (size_t device_index = 0; device_index < ggml_parallel_devices.size(); device_index++) {
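This last call site is the sizing pass in ggml_backend_tp_graph_compute: it only measures gather buffer sizes, so it passes nullptr for the new flush_compute parameter, and the `if (flush_compute)` guards in the walker make that safe. Worth noting because invoking an empty std::function is not a no-op, as this small check illustrates:

    #include <cstdio>
    #include <functional>

    int main() {
        std::function<void(int)> flush = nullptr; // as passed by the sizing pass

        if (flush) flush(0); // guarded call: skipped, no std::bad_function_call
        else std::printf("measurement-only pass: no flush callback\n");
    }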
