Skip to content

Commit 50c5a53

Browse files
committed
more split view fixes
1 parent c00fe0c commit 50c5a53

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -563,11 +563,6 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
563563
rejoined->op = GGML_OP_NONE;
564564
}
565565

566-
auto view_src = src;
567-
while (view_src->view_src) {
568-
view_src = view_src->view_src;
569-
}
570-
571566
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
572567
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
573568
auto dev = ggml_parallel_devices[j];
@@ -579,7 +574,7 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
579574

580575
size_t reduce_offset = 0;
581576
for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
582-
auto view = ggml_backend_tp_clone_tensor(view_src);
577+
auto view = ggml_backend_tp_clone_tensor(src);
583578
src_extra->rejoined_tensor_views[j][i] = view;
584579

585580
view->op = GGML_OP_NONE;
@@ -592,26 +587,26 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
592587
}
593588
}
594589
else if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
595-
auto splits = get_row_splits(view_src);
590+
auto splits = get_row_splits(src);
596591
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
597592
auto rejoined = src_extra->converted_tensors[j];
598593

599594
size_t row_offset = 0;
600595
for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
601-
auto view = ggml_backend_tp_clone_tensor(view_src);
596+
auto view = ggml_backend_tp_clone_tensor(src);
602597
src_extra->rejoined_tensor_views[j][i] = view;
603598

604599
view->op = GGML_OP_NONE;
605600
view->view_src = rejoined;
606601
view->ne[1] = splits.split[i];
607602
// adjust the offset to the start of the row in the destination tensor
608-
view->view_offs = view_src->nb[1] * row_offset;
603+
view->view_offs = src->nb[1] * row_offset;
609604

610605
row_offset += splits.split[j];
611606
}
612607
}
613608
}
614-
else {
609+
else if (src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
615610
// a typical tensor that is split across multiple devices is usually column split.
616611
// this is because the weight matrices are transposed and row split, resulting in
617612
// column split result. this cannot be concatenated memory-wise (row-wise).
@@ -621,25 +616,28 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
621616
// A A B B C C D D
622617
// A A B B C C D D
623618
// A A B B C C D D
624-
auto splits = get_col_splits(view_src);
619+
auto splits = get_col_splits(src);
625620
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
626621
auto rejoined = src_extra->converted_tensors[j];
627622

628623
size_t col_offset = 0;
629624
for (size_t i = 0; i < ggml_parallel_devices.size(); i++) {
630-
auto view = ggml_backend_tp_clone_tensor(view_src);
625+
auto view = ggml_backend_tp_clone_tensor(src);
631626
src_extra->rejoined_tensor_views[j][i] = view;
632627

633628
view->op = GGML_OP_NONE;
634629
view->view_src = rejoined;
635630
view->ne[0] = splits.split[i];
636631
// adjust the offset to the start of the column in the destination tensor
637-
view->view_offs = view_src->nb[1] / view_src->ne[0] * col_offset;
632+
view->view_offs = src->nb[1] / src->ne[0] * col_offset;
638633

639634
col_offset += splits.split[j];
640635
}
641636
}
642637
}
638+
else {
639+
GGML_ABORT("Tensor %s is split as %d, but rejoin requested.\n", src->name, src_extra->split_tensors);
640+
}
643641
}
644642

645643
static int memdiff_index(const void *a, const void *b, size_t length) {
@@ -1139,7 +1137,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
11391137
};
11401138

11411139
auto no_split_view = [](ggml_tensor *src, ggml_tensor_parallel_extra *src_extra) {
1142-
if (src_extra->split_tensors && src->view_src) {
1140+
if (src_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
11431141
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", src->name);
11441142
}
11451143
};

0 commit comments

Comments
 (0)