@@ -563,11 +563,6 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
         rejoined->op = GGML_OP_NONE;
     }
 
-    auto view_src = src;
-    while (view_src->view_src) {
-        view_src = view_src->view_src;
-    }
-
     if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto dev = ggml_parallel_devices[j];
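
For context, the loop removed above is the usual pattern for resolving a ggml view chain to its root tensor; a minimal sketch, assuming only the public `view_src` member of `ggml_tensor` from ggml.h:

    // follow view_src links until the non-view root tensor is reached
    static const ggml_tensor * view_root(const ggml_tensor * t) {
        while (t->view_src) {
            t = t->view_src;
        }
        return t;
    }

The rejoin views below are now built over `src` itself rather than its view root.
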
@@ -579,7 +574,7 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
 
             size_t reduce_offset = 0;
             for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
-                auto view = ggml_backend_tp_clone_tensor(view_src);
+                auto view = ggml_backend_tp_clone_tensor(src);
                 src_extra->rejoined_tensor_views[j][i] = view;
 
                 view->op = GGML_OP_NONE;
@@ -592,26 +587,26 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
         }
     }
     else if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
-        auto splits = get_row_splits(view_src);
+        auto splits = get_row_splits(src);
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto rejoined = src_extra->converted_tensors[j];
 
             size_t row_offset = 0;
             for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
-                auto view = ggml_backend_tp_clone_tensor(view_src);
+                auto view = ggml_backend_tp_clone_tensor(src);
                 src_extra->rejoined_tensor_views[j][i] = view;
 
                 view->op = GGML_OP_NONE;
                 view->view_src = rejoined;
                 view->ne[1] = splits.split[i];
                 // adjust the offset to the start of the row in the destination tensor
-                view->view_offs = view_src->nb[1] * row_offset;
+                view->view_offs = src->nb[1] * row_offset;
 
                 row_offset += splits.split[j];
             }
         }
     }
-    else {
+    else if (src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
         // a typical tensor that is split across multiple devices is usually column split.
         // this is because the weight matrices are transposed and row split, resulting in
         // a column split result. this cannot be concatenated contiguously in memory (row-wise).
@@ -621,25 +616,28 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
         // A A B B C C D D
         // A A B B C C D D
         // A A B B C C D D
-        auto splits = get_col_splits(view_src);
+        auto splits = get_col_splits(src);
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto rejoined = src_extra->converted_tensors[j];
 
             size_t col_offset = 0;
             for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
-                auto view = ggml_backend_tp_clone_tensor(view_src);
+                auto view = ggml_backend_tp_clone_tensor(src);
                 src_extra->rejoined_tensor_views[j][i] = view;
 
                 view->op = GGML_OP_NONE;
                 view->view_src = rejoined;
                 view->ne[0] = splits.split[i];
                 // adjust the offset to the start of the column in the destination tensor
-                view->view_offs = view_src->nb[1] / view_src->ne[0] * col_offset;
+                view->view_offs = src->nb[1] / src->ne[0] * col_offset;
 
                 col_offset += splits.split[j];
             }
         }
     }
+    else {
+        GGML_ABORT("Tensor %s is split as %d, but rejoin requested.\n", src->name, src_extra->split_tensors);
+    }
 }
 
 static int memdiff_index(const void *a, const void *b, size_t length) {
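
The two offset formulas above differ only in granularity: a row-split view starts at a whole-row byte offset (`nb[1] * row_offset`), while a column-split view starts partway into each row (`nb[1] / ne[0] * col_offset`, i.e. element size times column index, which assumes a plain non-quantized element type). A minimal standalone sketch of that arithmetic, with hypothetical sizes and plain variables in place of the ggml fields:

    #include <cstdio>
    #include <cstddef>

    int main() {
        // hypothetical rejoined tensor: 8 float columns x 4 rows, 2 columns per device
        const size_t ne0      = 8;                    // elements per row
        const size_t nb1      = ne0 * sizeof(float);  // bytes per row
        const size_t split[4] = {2, 2, 2, 2};         // per-device column counts (A A B B C C D D)

        size_t col_offset = 0;
        for (size_t i = 0; i < 4; i++) {
            // same formula as view->view_offs in the column branch: element size * column offset
            const size_t view_offs = nb1 / ne0 * col_offset;
            std::printf("device %zu: view starts %zu bytes into each row\n", i, view_offs);
            col_offset += split[i];
        }
        return 0;
    }

With 4-byte floats this prints offsets 0, 8, 16 and 24; a row split would instead step in multiples of the full 32-byte row stride.
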
@@ -1139,7 +1137,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
     };
 
     auto no_split_view = [](ggml_tensor *src, ggml_tensor_parallel_extra *src_extra) {
-        if (src_extra->split_tensors && src->view_src) {
+        if (src_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
             GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", src->name);
         }
     };
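
Both hunks replace an open-ended fallback with an explicit check on the split kind, so an unexpected value aborts loudly instead of being rejoined incorrectly. A minimal sketch of that guard pattern, with a hypothetical enum standing in for the GGML_TP_SPLIT_* values referenced above:

    #include <cstdio>
    #include <cstdlib>

    // hypothetical stand-in for the split kinds named in the diff
    enum tp_split { TP_SPLIT_REDUCE, TP_SPLIT_ROWS, TP_SPLIT_COLUMNS, TP_SPLIT_VIEW };

    static void rejoin(tp_split kind) {
        switch (kind) {
            case TP_SPLIT_REDUCE:  /* per-device full-size views, reduced later */ break;
            case TP_SPLIT_ROWS:    /* views offset by whole rows                */ break;
            case TP_SPLIT_COLUMNS: /* strided views offset within each row      */ break;
            default:               /* views (and anything new) are rejected     */
                std::fprintf(stderr, "split kind %d cannot be rejoined\n", (int) kind);
                std::abort();
        }
    }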