Skip to content

Commit 1674b56

Browse files
committed
fixes
1 parent 2f88e26 commit 1674b56

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
970970
auto src2_extra = src2 ? (ggml_tensor_parallel_extra *)src2->extra : nullptr;
971971
auto src3_extra = src3 ? (ggml_tensor_parallel_extra *)src3->extra : nullptr;
972972

973-
auto create_default_tensors = [&]() {
973+
auto create_default_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
974974
extra->split_tensors = GGML_TP_SPLIT_NONE;
975975
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
976976
auto dev = ggml_parallel_devices[j];
@@ -979,6 +979,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
979979
}
980980
};
981981

982+
auto create_default_tensors = [&]() {
983+
create_default_tensors_for(tensor, extra);
984+
};
985+
982986
auto create_reduce_tensors = [&]() {
983987
extra->split_tensors = GGML_TP_SPLIT_REDUCE;
984988
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
@@ -1354,9 +1358,13 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
13541358
}
13551359

13561360
case GGML_OP_RMS_NORM:
1357-
no_split_view(src0, src0_extra);
1358-
if (tensor->view_src) {
1359-
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
1361+
// auto src0_viewsrc = src0->view_src;
1362+
// auto src0_viewsrc_extra = (ggml_tensor_parallel_extra *)src0_viewsrc->extra;
1363+
// no_split_view(src0, src0_extra);
1364+
if (src0->view_src) {
1365+
ensure_rejoined(tensor, src0->view_src);
1366+
create_default_tensors_for(src0, src0_extra);
1367+
set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_NONE);
13601368
}
13611369

13621370
ensure_rejoined(tensor, src0);
@@ -1394,7 +1402,12 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
13941402
}
13951403
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
13961404
ensure_column_split(src0);
1397-
ensure_column_split(src1);
1405+
if (src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1406+
ensure_reduce_split_views(src1);
1407+
}
1408+
else {
1409+
ensure_column_split(src1);
1410+
}
13981411
create_column_split_tensors();
13991412
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
14001413
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);

0 commit comments

Comments
 (0)