@@ -936,15 +936,21 @@ static ggml_tensor* ggml_backend_tp_node_compute_split(int device_index, ggml_te
 }
 
 static bool immediate_compute = true;
-static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute) {
+static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute, std::function<void(int, std::set<ggml_tensor*>)> flush_compute) {
     std::set<ggml_tensor*> pending_gathers;
     for (int node_index = 0; node_index < cgraph->n_nodes; node_index++) {
         auto tensor = cgraph->nodes[node_index];
         auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
 
         // wait for async memcpy to finish if needed
-        if ((extra->needs_src_rejoin || immediate_compute) && pending_gathers.size()) {
-            if (gather_pending && !gather_pending(node_index, pending_gathers)) {
+        if (extra->needs_src_rejoin && pending_gathers.size()) {
+            if (!immediate_compute) {
+                if (flush_compute) {
+                    flush_compute(node_index, pending_gathers);
+                }
+            }
+
+            if (!gather_pending(node_index, pending_gathers)) {
                 return;
             }
             pending_gathers.clear();
@@ -957,6 +963,12 @@ static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::func
         if (extra->has_rejoin) {
             pending_gathers.insert(tensor);
         }
+
+        if (immediate_compute) {
+            if (flush_compute) {
+                flush_compute(node_index, pending_gathers);
+            }
+        }
     }
 }
 
@@ -975,9 +987,13 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
     auto device_index = thread->device_index;
     auto be = ggml_parallel_backends[device_index];
 
+    if (!be->iface.cpy_tensor2d_async) {
+        GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
+    }
+
     int rejoins = 0;
 
-    auto flush_compute = [&](int node_index) {
+    auto flush_compute = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         if (backend_graph->n_nodes) {
             auto status = be->iface.graph_compute(be, backend_graph);
             if (status != GGML_STATUS_SUCCESS) {
@@ -986,19 +1002,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
             backend_graph->n_nodes = 0;
         }
         thread->end = node_index;
-    };
-
-    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
-        flush_compute(node_index);
 
         for (auto & tensor : pending_gathers) {
             auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
             auto wrapped = extra->tensors[device_index];
 
-            if (!be->iface.cpy_tensor2d_async) {
-                GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
-            }
-
             // async copies
             for (size_t other_device_index = 0; other_device_index < ggml_parallel_devices.size(); other_device_index++) {
                 auto other_be = ggml_parallel_backends[other_device_index];
@@ -1018,7 +1026,9 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
                 }
             }
         }
+    };
 
+    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         rejoins++;
         // synchronize self and then release peers
         ggml_backend_synchronize(be);
@@ -1043,16 +1053,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
         backend_graph->nodes[backend_graph->n_nodes++] = ggml_backend_tp_node_compute_split(device_index, tensor);
         extra->computed[device_index] = true;
 
-        if (immediate_compute) {
-            flush_compute(node_index);
-            ggml_backend_synchronize(be);
-        }
-
         return true;
     };
 
-    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute);
-    flush_compute(cgraph->n_nodes);
+    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute, flush_compute);
+    flush_compute(cgraph->n_nodes, std::set<ggml_tensor*>());
 
     thread->done.unlock();
 
@@ -1817,7 +1822,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
     ggml_backend_tp_buffer_compute_graph(cgraph, nullptr, [&](int node_index, ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
         do_init(tensor, extra);
         return true;
-    });
+    }, nullptr);
 
     // calculate the sizes needed for gathering the tensors.
     // this must happen on main thread to prevent race conditions on gather tensor setup.
@@ -1849,7 +1854,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
             gather_buft_sizes_cur[device_index] += extra->gather_buft_sizes[device_index];
         }
         return true;
-    });
+    }, nullptr);
 
     // allocate the gather buffers
     for (size_t device_index = 0; device_index < ggml_parallel_devices.size(); device_index++) {
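Below is a minimal, self-contained sketch of the control flow this change introduces: the graph walk now receives a separate flush_compute callback, flushes the accumulated backend graph only just before a gather in deferred mode, and flushes after every node when immediate_compute is set. Node, compute_graph, and the callback bodies are simplified stand-ins, and the placement of the per-node compute call is inferred (those lines fall outside the shown hunks); only the flush/gather ordering mirrors the diff.

// Sketch only: stand-in types, not the ggml structures used above.
#include <cstdio>
#include <functional>
#include <set>
#include <vector>

struct Node { int id; bool needs_src_rejoin; bool has_rejoin; };

static bool immediate_compute = true;

static void compute_graph(std::vector<Node> & nodes,
                          std::function<bool(int, std::set<Node*>)> gather_pending,
                          std::function<bool(int, Node *)> compute,
                          std::function<void(int, std::set<Node*>)> flush_compute) {
    std::set<Node*> pending_gathers;
    for (int i = 0; i < (int)nodes.size(); i++) {
        Node * node = &nodes[i];

        // A node that consumes gathered sources must wait for outstanding gathers.
        if (node->needs_src_rejoin && pending_gathers.size()) {
            // Deferred mode: flush the backend graph recorded so far only now,
            // right before the gather that depends on it.
            if (!immediate_compute && flush_compute) {
                flush_compute(i, pending_gathers);
            }
            if (!gather_pending(i, pending_gathers)) {
                return;
            }
            pending_gathers.clear();
        }

        // Record the node (placement inferred; not part of the shown hunks).
        if (!compute(i, node)) {
            return;
        }
        if (node->has_rejoin) {
            pending_gathers.insert(node);
        }

        // Immediate mode: flush after every node instead of batching.
        if (immediate_compute && flush_compute) {
            flush_compute(i, pending_gathers);
        }
    }
}

int main() {
    std::vector<Node> nodes = {
        {0, false, true},  // produces a tensor that must be gathered
        {1, true, false},  // consumes the gathered tensor
    };
    compute_graph(nodes,
        [](int i, std::set<Node*> pending) {
            std::printf("gather before node %d (%zu pending)\n", i, pending.size());
            return true;
        },
        [](int i, Node *) {
            std::printf("record node %d\n", i);
            return true;
        },
        [](int i, std::set<Node*>) {
            std::printf("flush up to node %d\n", i);
        });
    return 0;
}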