@@ -45,6 +45,7 @@ enum ggml_tp_split_type {
     GGML_TP_SPLIT_ROWS    = 1,
     GGML_TP_SPLIT_COLUMNS = 2,
     GGML_TP_SPLIT_REDUCE  = 3,
+    GGML_TP_SPLIT_DIM2    = 4,
 };
 
 struct ggml_tensor_parallel_extra {
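The new GGML_TP_SPLIT_DIM2 mode slices a tensor along ne[2], the head dimension in ggml's attention layouts. As a rough sketch of what the sizing could look like (the actual get_dim_splits() helper used later in this diff is not shown, so this stand-in is an assumption), an even partition over the parallel devices:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for get_dim_splits(): distribute ne2 slices as
// evenly as possible over n_devices, mirroring the .split[j] usage below.
struct ggml_tp_dim_splits {
    std::vector<int64_t> split;
};

static ggml_tp_dim_splits get_dim_splits_sketch(int64_t ne2, int64_t n_devices) {
    ggml_tp_dim_splits s;
    s.split.resize(n_devices);
    for (int64_t j = 0; j < n_devices; j++) {
        // e.g. 32 heads over 3 devices -> {11, 11, 10}
        s.split[j] = ne2 / n_devices + (j < ne2 % n_devices ? 1 : 0);
    }
    return s;
}
```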
@@ -225,7 +226,7 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
             if (i == 1 && tensor->op == GGML_OP_ROPE) {
                 wrapped->src[i] = src_extra->tensors[j];
             }
-            if (i > 0 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+            if (i == 3 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
                 wrapped->src[i] = src_extra->tensors[j];
             }
 
@@ -326,8 +327,9 @@ static bool is_split_compatible(ggml_tensor * tensor) {
         case GGML_OP_RESHAPE:
         case GGML_OP_PERMUTE:
         case GGML_OP_ROPE:
-        // case GGML_OP_FLASH_ATTN_EXT:
-        // case GGML_OP_CPY:
+        case GGML_OP_FLASH_ATTN_EXT:
+        case GGML_OP_CPY:
+        // case GGML_OP_GET_ROWS:
             return true;
         default:
             return false;
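With GGML_OP_FLASH_ATTN_EXT and GGML_OP_CPY now whitelisted, those nodes can consume split operands instead of forcing a rejoin first. A minimal sketch of how the predicate is typically consulted (this call site is assumed, not part of the diff; ensure_rejoined is the backend's own helper):

```cpp
// Sketch: ops outside the whitelist fall back to a full-tensor rejoin of
// their sources before the node is scheduled.
if (!is_split_compatible(node)) {
    for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) {
        ensure_rejoined(node, node->src[i]);
    }
}
```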
@@ -393,9 +395,80 @@ static size_t ggml_backend_tp_buffer_type_get_alignment(ggml_backend_buffer_type
     GGML_UNUSED(buft);
 }
 
-static ggml_status ensure_split(const ggml_tensor *src) {
+static ggml_status ensure_dim2_split(const ggml_tensor *src) {
     auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
     if (src_extra->split_tensors) {
+        if (src_extra->split_tensors != GGML_TP_SPLIT_DIM2) {
+            GGML_ABORT("Tensor %s is already split as %d, but requested to be split along dim2.\n", src->name, src_extra->split_tensors);
+        }
+        // this tensor is already split, so we don't need to do anything
+        return GGML_STATUS_SUCCESS;
+    }
+    if (src_extra->converted_tensors[0]) {
+        return GGML_STATUS_SUCCESS;
+    }
+
+    // no actual conversion needs to take place, the split tensors can be
+    // created by using offsets within the original tensor.
+    auto splits = get_dim_splits(src->ne[2]);
+    size_t offset = 0;
+    for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
+        auto split = ggml_backend_tp_clone_tensor(src);
+        split->op = GGML_OP_NONE;
+        src_extra->converted_tensors[j] = split;
+
+        split->buffer = src_extra->tensors[j]->buffer;
+        split->data = (char *) src_extra->tensors[j]->data + offset;
+
+        // note that only the dimension needs to be changed, retaining the stride allows
+        // using the original tensor data for the dim2 split.
+        split->ne[2] = splits.split[j];
+
+        offset += src->nb[2] * splits.split[j];
+    }
+    return GGML_STATUS_SUCCESS;
+}
+
+static ggml_status ensure_row_split(const ggml_tensor *src) {
+    auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
+    if (src_extra->split_tensors) {
+        if (src_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
+            GGML_ABORT("Tensor %s is already split as %d, but requested to be split as rows.\n", src->name, src_extra->split_tensors);
+        }
+        // this tensor is already split, so we don't need to do anything
+        return GGML_STATUS_SUCCESS;
+    }
+    if (src_extra->converted_tensors[0]) {
+        return GGML_STATUS_SUCCESS;
+    }
+
+    // no actual conversion needs to take place, the split tensors can be
+    // created by using offsets within the original tensor.
+    auto splits = get_row_splits(src);
+    size_t offset = 0;
+    for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
+        auto split = ggml_backend_tp_clone_tensor(src);
+        split->op = GGML_OP_NONE;
+        src_extra->converted_tensors[j] = split;
+
+        split->buffer = src_extra->tensors[j]->buffer;
+        split->data = (char *) src_extra->tensors[j]->data + offset;
+
+        // note that only the dimension needs to be changed, retaining the stride allows
+        // using the original tensor data for the row split.
+        split->ne[1] = splits.split[j];
+
+        offset += src->nb[1] * splits.split[j];
+    }
+    return GGML_STATUS_SUCCESS;
+}
+
+static ggml_status ensure_column_or_reduce_split(const ggml_tensor *src) {
+    auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
+    if (src_extra->split_tensors) {
+        if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
+            GGML_ABORT("Tensor %s is already split as %d, but requested to be split as columns.\n", src->name, src_extra->split_tensors);
+        }
         // this tensor is already split, so we don't need to do anything
         return GGML_STATUS_SUCCESS;
     }
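Both new helpers exploit the same property: because the split views keep the original byte strides nb[], a slice along one dimension is just a fixed byte offset into the parent data. A worked example, assuming a contiguous F32 tensor:

```cpp
// ne = {128, 64, 32, 1}, F32, contiguous:
//   nb[0] = 4                   (one element)
//   nb[1] = 4   * 128 = 512     (one row)
//   nb[2] = 512 * 64  = 32768   (one dim-2 slice, e.g. one head)
// A row split handing 16 rows to the next device advances by
// 16 * nb[1] = 8192 bytes; a dim-2 split of 16 slices advances by
// 16 * nb[2] = 524288 bytes. Only ne[1] or ne[2] shrinks per device.
static size_t split_offset(const ggml_tensor * t, int dim, int64_t start) {
    return (size_t) start * t->nb[dim];
}
```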
@@ -547,7 +619,7 @@ static ggml_status ensure_rejoined(const ggml_tensor *reason, const ggml_tensor
     src_extra->has_rejoin = true;
 
     // if (reason && reason != src) {
-    //     printf("Rejoining tensor for %s %d %d\n", ggml_op_name(reason->op), src->ne[0], src->ne[1]);
+    //     printf("Rejoining tensor for %s %s\n", ggml_op_name(reason->op), ggml_op_name(src->op));
     // }
 
     const auto alignment = ggml_backend_tp_buffer_type_get_alignment(src->buffer->buft);
@@ -907,16 +979,6 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
             continue;
         }
 
-        auto vs = tensor;
-        while (vs->view_src) {
-            vs = vs->view_src;
-        }
-        if ((vs->ne[2] > 1 || vs->ne[3] > 1)) {
-            if (!ggml_is_contiguous(vs)) {
-                GGML_ABORT("Tensor %s has more than 2 dimensions, not supported for TP.\n", tensor->name);
-            }
-        }
-
         pending_rejoins.insert(tensor);
 
         if (!be->iface.cpy_tensor2d_async) {
@@ -969,7 +1031,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
     std::set<ggml_backend_tp_buffer_context *> contexts;
     for (auto tensor : tensors) {
         auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
-        if (tensor->view_src && extra->needs_src_rejoin) {
+        if (tensor->view_src) {
             auto view_src = tensor->view_src;
             auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
             if (view_src_extra->has_rejoin) {
@@ -1563,14 +1625,19 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
         auto src0 = tensor->src[0];
         auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra;
     }
-    else if (tensor->op != GGML_OP_UNARY && tensor->op != GGML_OP_FLASH_ATTN_EXT && tensor->op != GGML_OP_ROPE) {
+    else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+        ensure_row_split(tensor->src[0]);
+        ensure_dim2_split(tensor->src[1]);
+        ensure_dim2_split(tensor->src[2]);
+    }
+    else if (tensor->op != GGML_OP_UNARY && tensor->op != GGML_OP_ROPE) {
         // printf("ggml_backend_tp_buffer_init_tensor: splitting tensor %s with op %s\n", tensor->name, ggml_op_name(tensor->op));
         for (int i = 0; i < GGML_MAX_SRC; i++) {
             auto src = tensor->src[i];
             if (!src) {
                 break;
             }
-            ensure_split(src);
+            ensure_column_or_reduce_split(src);
         }
     }
 }
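For GGML_OP_FLASH_ATTN_EXT the three tensor sources get different splits: Q is sliced along ne[1] and K/V along ne[2], while the mask in src[3] is never split; the `i == 3` case added to unwrap_tensor above hands each device its whole per-device mask copy. Illustrative per-device shapes with 2 devices (the concrete numbers here are made up):

```cpp
// q:    ne = {128, 512,  32, 1} -> row  split: each device sees ne[1] = 256
// k, v: ne = {128, 4096, 32, 1} -> dim2 split: each device sees ne[2] = 16
// mask: src[3] stays whole; unwrap_tensor assigns the per-device copy as-is
```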
@@ -1585,8 +1653,8 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
     bool can_reduce = (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src0_extra->has_rejoin) || (src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src1_extra->has_rejoin);
     bool double_reduce = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE;
     if (can_reduce) {
-        ensure_split(src0);
-        ensure_split(src1);
+        ensure_column_or_reduce_split(src0);
+        ensure_column_or_reduce_split(src1);
         extra->split_tensors = GGML_TP_SPLIT_REDUCE;
         split_reduced_add = true;
     }
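The reason an add of two reduce-split operands can itself stay reduce-split: each operand is an elementwise partial sum across devices, and partial sums commute with addition, so the final all-reduce can be deferred past the add. A minimal numeric sketch:

```cpp
// With 2 devices: a = a0 + a1 and b = b0 + b1 (device-local partials).
float a_dev[2] = { 1.0f,  2.0f};  // partials of a = 3
float b_dev[2] = {10.0f, 20.0f};  // partials of b = 30
float c_dev[2] = {a_dev[0] + b_dev[0], a_dev[1] + b_dev[1]};
// c_dev holds valid partials of c = a + b: 11 + 22 == 33, so the
// all-reduce for c can happen later (or fold into a later reduce).
```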
@@ -1741,6 +1809,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
         wrapped->nb[1] = src_extra->tensors[j]->nb[1];
         wrapped->nb[2] = src_extra->tensors[j]->nb[2];
         wrapped->nb[3] = src_extra->tensors[j]->nb[3];
+        ensure_rejoined(nullptr, tensor);
     }
     else if (tensor->op == GGML_OP_PERMUTE) {
         auto src = tensor->src[0];