@@ -892,14 +892,20 @@ static ggml_tensor* ggml_backend_tp_node_compute_split(int device_index, ggml_te
 }
 
 static bool immediate_compute = false;
-static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute) {
+static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute, std::function<void(int, std::set<ggml_tensor*>)> flush_compute) {
     std::set<ggml_tensor*> pending_gathers;
     for (int node_index = 0; node_index < cgraph->n_nodes; node_index++) {
         auto tensor = cgraph->nodes[node_index];
         auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
 
         // wait for async memcpy to finish if needed
-        if ((extra->needs_src_rejoin || immediate_compute) && pending_gathers.size()) {
+        if (extra->needs_src_rejoin && pending_gathers.size()) {
+            if (!immediate_compute) {
+                if (flush_compute) {
+                    flush_compute(node_index, pending_gathers);
+                }
+            }
+
             if (!gather_pending(node_index, pending_gathers)) {
                 return;
             }
@@ -913,6 +919,12 @@ static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::func
         if (extra->has_rejoin) {
             pending_gathers.insert(tensor);
         }
+
+        if (immediate_compute) {
+            if (flush_compute) {
+                flush_compute(node_index, pending_gathers);
+            }
+        }
     }
 }
 
@@ -931,9 +943,13 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
     auto device_index = thread->device_index;
     auto be = ggml_parallel_backends[device_index];
 
+    if (!be->iface.cpy_tensor2d_async) {
+        GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
+    }
+
     int rejoins = 0;
 
-    auto flush_compute = [&](int node_index) {
+    auto flush_compute = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         if (backend_graph->n_nodes) {
             auto status = be->iface.graph_compute(be, backend_graph);
             if (status != GGML_STATUS_SUCCESS) {
@@ -942,19 +958,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
             backend_graph->n_nodes = 0;
         }
         thread->end = node_index;
-    };
-
-    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
-        flush_compute(node_index);
 
         for (auto & tensor : pending_gathers) {
             auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
             auto wrapped = extra->tensors[device_index];
 
-            if (!be->iface.cpy_tensor2d_async) {
-                GGML_ABORT("Backend %s does not support async tensor copy.\n", be->iface.get_name(be));
-            }
-
             // async copies
             for (size_t other_device_index = 0; other_device_index < ggml_parallel_devices.size(); other_device_index++) {
                 auto other_be = ggml_parallel_backends[other_device_index];
@@ -974,7 +982,9 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
                 }
             }
         }
+    };
 
+    auto gather_pending = [&](int node_index, std::set<ggml_tensor*> pending_gathers) {
         rejoins++;
         // synchronize self and then release peers
         ggml_backend_synchronize(be);
@@ -999,16 +1009,11 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
         backend_graph->nodes[backend_graph->n_nodes++] = ggml_backend_tp_node_compute_split(device_index, tensor);
         extra->computed[device_index] = true;
 
-        if (immediate_compute) {
-            flush_compute(node_index);
-            ggml_backend_synchronize(be);
-        }
-
         return true;
     };
 
-    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute);
-    flush_compute(cgraph->n_nodes);
+    ggml_backend_tp_buffer_compute_graph(cgraph, gather_pending, compute, flush_compute);
+    flush_compute(cgraph->n_nodes, std::set<ggml_tensor*>());
 
     thread->done.unlock();
 
@@ -1061,7 +1066,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
             gather_buft_sizes_cur[device_index] += extra->gather_buft_sizes[device_index];
         }
         return true;
-    });
+    }, nullptr);
 
     // allocate the gather buffers
     for (size_t device_index = 0; device_index < ggml_parallel_devices.size(); device_index++) {
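
For orientation, below is a minimal, self-contained sketch of the control flow this change introduces: batched node compute is flushed before any pending gathers are serviced, and when the immediate_compute debug path is enabled the flush runs after every node instead. The types and names here (node, compute_graph, immediate) are illustrative placeholders rather than the real ggml structures, and the clearing of the pending set inside the loop is an assumption about behavior not shown in the hunks above.

// Hypothetical, simplified model of the reworked traversal -- not the real ggml API.
#include <cstdio>
#include <functional>
#include <set>
#include <vector>

struct node { int id; bool needs_src_rejoin; bool has_rejoin; };

static bool immediate = false;  // stands in for the immediate_compute debug flag

static void compute_graph(const std::vector<node> & nodes,
                          std::function<bool(int, std::set<const node *>)> gather_pending,
                          std::function<bool(int, const node &)> compute,
                          std::function<void(int, std::set<const node *>)> flush_compute) {
    std::set<const node *> pending;
    for (int i = 0; i < (int) nodes.size(); i++) {
        const node & n = nodes[i];
        if (n.needs_src_rejoin && !pending.empty()) {
            if (!immediate && flush_compute) {
                flush_compute(i, pending);   // run batched work before gathering
            }
            if (!gather_pending(i, pending)) {
                return;
            }
            pending.clear();                 // assumed: gathers are consumed here
        }
        if (!compute(i, n)) {
            return;
        }
        if (n.has_rejoin) {
            pending.insert(&n);
        }
        if (immediate && flush_compute) {
            flush_compute(i, pending);       // debug path: flush after every node
        }
    }
}

int main() {
    std::vector<node> g = { {0, false, true}, {1, true, false}, {2, false, false} };
    compute_graph(g,
        [](int i, std::set<const node *> p) { std::printf("gather %zu tensor(s) before node %d\n", p.size(), i); return true; },
        [](int i, const node &)             { std::printf("record node %d\n", i); return true; },
        [](int i, std::set<const node *>)   { std::printf("flush batched work up to node %d\n", i); });
    return 0;
}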