@@ -970,7 +970,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
970970 auto src2_extra = src2 ? (ggml_tensor_parallel_extra *)src2->extra : nullptr ;
971971 auto src3_extra = src3 ? (ggml_tensor_parallel_extra *)src3->extra : nullptr ;
972972
973- auto create_default_tensors = [&]( ) {
973+ auto create_default_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra ) {
974974 extra->split_tensors = GGML_TP_SPLIT_NONE;
975975 for (size_t j = 0 ; j < ggml_parallel_devices.size (); j++) {
976976 auto dev = ggml_parallel_devices[j];
@@ -979,6 +979,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
979979 }
980980 };
981981
982+ auto create_default_tensors = [&]() { // Zero-argument shorthand: captures the enclosing scope by reference so call sites need not pass arguments.
983+ create_default_tensors_for (tensor, extra); // Forwards the captured `tensor` and `extra` to the parameterized helper.
984+ };
985+
982986 auto create_reduce_tensors = [&]() {
983987 extra->split_tensors = GGML_TP_SPLIT_REDUCE;
984988 for (size_t j = 0 ; j < ggml_parallel_devices.size (); j++) {
@@ -1354,9 +1358,13 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
13541358 }
13551359
13561360 case GGML_OP_RMS_NORM:
1357- no_split_view (src0, src0_extra);
1358- if (tensor->view_src ) {
1359- GGML_ABORT (" Tensor %s has view source tensors, which are not supported for tensor parallelism.\n " , tensor->name );
1361+ // auto src0_viewsrc = src0->view_src;
1362+ // auto src0_viewsrc_extra = (ggml_tensor_parallel_extra *)src0_viewsrc->extra;
1363+ // no_split_view(src0, src0_extra);
1364+ if (src0->view_src ) {
1365+ ensure_rejoined (tensor, src0->view_src );
1366+ create_default_tensors_for (src0, src0_extra);
1367+ set_src_tensor_for (src0, src0_extra, 0 , GGML_TP_SPLIT_NONE);
13601368 }
13611369
13621370 ensure_rejoined (tensor, src0);
@@ -1394,7 +1402,12 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
13941402 }
13951403 else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
13961404 ensure_column_split (src0);
1397- ensure_column_split (src1);
1405+ if (src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1406+ ensure_reduce_split_views (src1);
1407+ }
1408+ else {
1409+ ensure_column_split (src1);
1410+ }
13981411 create_column_split_tensors ();
13991412 set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
14001413 set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
0 commit comments