@@ -1122,6 +1122,72 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
         }
     };

+    bool has_init = false;
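+    // wires up per-device add nodes that fold a pending reduction into this op;
+    // it finalizes the tensor itself, so the fallback finish at the end of do_init is skipped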
+    auto create_reduce_op_tensors = [&] {
+        if (has_init) {
+            GGML_ABORT("Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
+        }
+        has_init = true;
+        ggml_backend_tp_finish_init_tensor(tensor);
+
+        // one of the sources must be a reduce tensor; the other may be a split tensor, an unsplit tensor, or even another reduce tensor
+        auto reduce_tensor = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE ? src0 : src1;
+        auto add_tensor = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE ? src1 : src0;
+        auto add_extra = (ggml_tensor_parallel_extra *)add_tensor->extra;
+        auto reduce_extra = (ggml_tensor_parallel_extra *)reduce_tensor->extra;
+
+        if (add_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
+            // a double-reduce add can be performed directly, without any views
+            for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
+                auto wrapped = extra->tensors[j];
+                auto reduce_op = ggml_backend_tp_clone_tensor(wrapped);
+                extra->reduce_op_tensors[j] = reduce_op;
+                reduce_op->buffer = wrapped->buffer;
+                reduce_op->view_src = wrapped;
+                reduce_op->view_offs = 0;
+                reduce_op->data = wrapped->data;
+
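+                // prefer the rejoined per-device view of each operand when one exists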
+                reduce_op->src[0] = reduce_extra->has_rejoin ? reduce_extra->rejoined_tensor_views[j][j] : reduce_extra->tensors[j];
+                reduce_op->src[1] = add_extra->has_rejoin ? add_extra->rejoined_tensor_views[j][j] : add_extra->tensors[j];
+            }
+        }
+        else {
+            auto splits = get_col_splits(tensor);
+            size_t col_offset = 0;
+
+            for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
+                auto wrapped = extra->tensors[j];
+
+                // create a col split view of the destination that can be used for reduction
+                auto reduce_op = ggml_backend_tp_clone_tensor(wrapped);
+                extra->reduce_op_tensors[j] = reduce_op;
+                reduce_op->buffer = wrapped->buffer;
+                reduce_op->view_src = wrapped;
+                reduce_op->view_offs = col_offset * wrapped->nb[0];
+                reduce_op->data = wrapped->data + reduce_op->view_offs;
+                reduce_op->ne[0] = splits.split[j];
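+                // e.g. with two devices and ne[0] == 4096 split 2048/2048, device 1's
+                // view starts at byte offset 2048 * nb[0] and spans 2048 columns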
+
+                // the reduce may have been rejoined; prefer the rejoined per-device view
+                auto reduce = reduce_extra->tensors[j];
+                if (reduce_extra->has_rejoin) {
+                    reduce = reduce_extra->rejoined_tensor_views[j][j];
+                }
+
+                // create a col split view of the reduced tensor
+                ensure_reduce_split_views(reduce_tensor);
+
+                auto reduce_op_src_view = reduce_extra->reduce_split_views[j];
+                reduce_op->src[0] = reduce_op_src_view;
+
+                auto add = add_extra->split_tensors == GGML_TP_SPLIT_NONE ? add_extra->converted_tensors[j] : add_extra->tensors[j];
+                reduce_op->src[1] = add;
+
+                col_offset += splits.split[j];
+            }
+        }
+    };
+
     auto no_reduce = [&](ggml_tensor *src, ggml_tensor_parallel_extra *src_extra) {
         if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
             ensure_rejoined(tensor, src);
@@ -1151,7 +1217,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                     wrapped->src[src_index] = src_extra->tensors[j];
                 }
             }
-            else if (split == GGML_TP_SPLIT_ROWS || GGML_TP_SPLIT_COLUMNS || GGML_TP_SPLIT_DIM2) {
+            else if (split == GGML_TP_SPLIT_ROWS || split == GGML_TP_SPLIT_COLUMNS || split == GGML_TP_SPLIT_DIM2) {
                 if (src_extra->split_tensors == GGML_TP_SPLIT_NONE) {
                     wrapped->src[src_index] = src_extra->converted_tensors[j];
                 }
@@ -1162,12 +1228,38 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                     GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src%d is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), src_index, src_extra->split_tensors, split);
                 }
             }
+            else if (split == GGML_TP_SPLIT_REDUCE) {
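+                // a reduce source is consumed as-is; a split source goes through its
+                // precomputed per-device reduce view; an unsplit source is used directly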
+                if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
+                    wrapped->src[src_index] = src_extra->tensors[j];
+                }
+                else if (src_extra->split_tensors) {
+                    wrapped->src[src_index] = src_extra->reduce_split_views[j];
+                }
+                else {
+                    wrapped->src[src_index] = src_extra->tensors[j];
+                }
+            }
             else {
                 GGML_ABORT("NYI\n");
             }
         }
     };

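+    // sanity-check that every expected source was wired up on the per-device tensors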
+    auto check_srcs = [&]() {
+        if (src0 && !extra->tensors[0]->src[0]) {
+            GGML_ABORT("Tensor %s failed to create src0 tensors for tensor parallelism.\n", tensor->name);
+        }
+        if (src1 && !extra->tensors[0]->src[1]) {
+            GGML_ABORT("Tensor %s failed to create src1 tensors for tensor parallelism.\n", tensor->name);
+        }
+        if (src2 && !extra->tensors[0]->src[2]) {
+            GGML_ABORT("Tensor %s failed to create src2 tensors for tensor parallelism.\n", tensor->name);
+        }
+        if (src3 && !extra->tensors[0]->src[3]) {
+            GGML_ABORT("Tensor %s failed to create src3 tensors for tensor parallelism.\n", tensor->name);
+        }
+    };
+
     switch (tensor->op) {
         case GGML_OP_ROPE: {
             no_split_view(src0, src0_extra);
@@ -1229,6 +1321,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                 set_src_tensor(2, GGML_TP_SPLIT_COLUMNS);
                 set_src_tensor(3, GGML_TP_SPLIT_DIM2);
             }
+            check_srcs();
             break;
         }

@@ -1255,8 +1348,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {

             if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
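+                // both sources hold per-device partial sums, so they can be added directly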
                 create_reduce_tensors();
-                set_src_tensor(0, GGML_TP_SPLIT_REDUCE);
-                set_src_tensor(1, GGML_TP_SPLIT_REDUCE);
+                create_reduce_op_tensors();
             }
             else if (src0_split_tensors && !src1_split_tensors) {
                 auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
@@ -1276,9 +1368,14 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                 set_src_tensor(1, GGML_TP_SPLIT_ROWS);
             }
             else if (src0_split_tensors == GGML_TP_SPLIT_REDUCE) {
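+                // src0 holds per-device partial sums; prepare src1's reduce split views
+                // so the reduction can be folded into the add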
+                ensure_reduce_split_views(src1);
+                create_reduce_tensors();
+                create_reduce_op_tensors();
+            }
+            else if (src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
+                ensure_reduce_split_views(src0);
                 create_reduce_tensors();
-                set_src_tensor(0, GGML_TP_SPLIT_REDUCE);
-                set_src_tensor(1, GGML_TP_SPLIT_REDUCE);
+                create_reduce_op_tensors();
             }
             else {
                 GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 and src1 have incompatible splits.\n", tensor->name, ggml_op_name(tensor->op));
@@ -1497,7 +1594,9 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
         GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism.\n", tensor->name, ggml_op_name(tensor->op));
     }

-    ggml_backend_tp_finish_init_tensor(tensor);
+    if (!has_init) {
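+        // not already finalized by create_reduce_op_tensors, so finish it here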
+        ggml_backend_tp_finish_init_tensor(tensor);
+    }
 }

 static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {