@@ -201,7 +201,7 @@ static bool ggml_backend_tp_is_split(ggml_tensor * tensor) {
 
 static bool is_split_compatible(ggml_tensor * tensor) {
     auto op = tensor->op;
-    if (op == GGML_OP_MUL_MAT) {
+    if (op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) {
         auto src1 = tensor->src[1];
         if (src1->buffer && ggml_backend_buft_is_tp_split(src1->buffer->buft)) {
             return false;
@@ -213,6 +213,7 @@ static bool is_split_compatible(ggml_tensor * tensor) {
     switch (op) {
         case GGML_OP_UNARY:
         case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID:
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
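
GGML_OP_MUL_MAT_ID is ggml's indirect matmul, used for MoE expert routing: it consumes a 3D stack of expert weight matrices the same way GGML_OP_MUL_MAT consumes a single weight matrix, which is why it can share the split-compatibility whitelist. A minimal graph-building sketch of the two ops side by side (standard ggml API; the sizes are illustrative and the activation shapes elide the broadcasting rules ggml_mul_mat_id applies):

#include "ggml.h"

static void build_example(ggml_context * ctx) {
    const int64_t n_embd = 4096, n_ff = 14336, n_tok = 8;
    const int64_t n_expert = 8, n_used = 2;

    // dense path: one weight matrix
    ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, n_embd, n_ff);
    ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tok);
    ggml_tensor * y_dense = ggml_mul_mat(ctx, w, x);          // GGML_OP_MUL_MAT

    // MoE path: a stack of expert matrices, selected per token by ids
    ggml_tensor * as  = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, n_embd, n_ff, n_expert);
    ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_used, n_tok);
    ggml_tensor * y_moe = ggml_mul_mat_id(ctx, as, x, ids);   // GGML_OP_MUL_MAT_ID

    // is_split_compatible() now treats y_dense and y_moe alike: both may keep
    // split outputs, unless the activations already sit in a TP-split buffer,
    // in which case the early return above forces a rejoin first.
    (void) y_dense; (void) y_moe;
}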
@@ -866,7 +867,7 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
     // if (device_index == 0) {
     //     for (int i = 0; i < backend_graph->n_nodes; i++) {
     //         auto tensor = backend_graph->nodes[i];
-    //         printf("TP %d: %s %s %x\n", node_index, ggml_op_name(tensor->op), tensor->name, tensor->data);
+    //         printf("TP %d: %s %s %x\n", node_index - backend_graph->n_nodes + i, ggml_op_name(tensor->op), tensor->name, tensor->data);
     //         for (int j = 0; j < GGML_MAX_SRC; j++) {
     //             auto src = tensor->src[j];
     //             if (!src) {
@@ -992,59 +993,73 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
         }
     };
 
-    auto create_row_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
+    auto prepare_wrapped = [](ggml_tensor * tensor, ggml_tensor * dims) {
+        auto wrapped = ggml_backend_tp_clone_tensor(dims);
+        if (dims != tensor) {
+            wrapped->op = tensor->op;
+            for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
+                wrapped->op_params[i] = tensor->op_params[i];
+            }
+        }
+        return wrapped;
+    };
+
+    auto create_row_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
+        dims = dims ? dims : tensor;
         extra->split_tensors = GGML_TP_SPLIT_ROWS;
-        auto splits = get_row_splits(tensor);
+        auto splits = get_row_splits(dims);
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto dev = ggml_parallel_devices[j];
-            auto wrapped = ggml_backend_tp_clone_tensor(tensor);
+            auto wrapped = prepare_wrapped(tensor, dims);
             extra->tensors[j] = wrapped;
 
             // update row count
             wrapped->ne[1] = splits.split[j];
             // adjust the stride for the new row count
-            wrapped->nb[2] = wrapped->nb[2] / tensor->ne[1] * splits.split[j];
-            wrapped->nb[3] = wrapped->nb[3] / tensor->ne[1] * splits.split[j];
+            wrapped->nb[2] = wrapped->nb[2] / dims->ne[1] * splits.split[j];
+            wrapped->nb[3] = wrapped->nb[3] / dims->ne[1] * splits.split[j];
         }
     };
 
     auto create_row_split_tensors = [&]() {
         create_row_split_tensors_for(tensor, extra);
     };
 
-    auto create_column_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
+    auto create_column_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
+        dims = dims ? dims : tensor;
         extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
-        auto splits = get_col_splits(tensor);
+        auto splits = get_col_splits(dims);
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto dev = ggml_parallel_devices[j];
-            auto wrapped = ggml_backend_tp_clone_tensor(tensor);
+            auto wrapped = prepare_wrapped(tensor, dims);
             extra->tensors[j] = wrapped;
 
             // update col count
             wrapped->ne[0] = splits.split[j];
             // adjust the stride for the new col count
-            wrapped->nb[1] = wrapped->nb[1] / tensor->ne[0] * splits.split[j];
-            wrapped->nb[2] = wrapped->nb[2] / tensor->ne[0] * splits.split[j];
-            wrapped->nb[3] = wrapped->nb[3] / tensor->ne[0] * splits.split[j];
+            wrapped->nb[1] = wrapped->nb[1] / dims->ne[0] * splits.split[j];
+            wrapped->nb[2] = wrapped->nb[2] / dims->ne[0] * splits.split[j];
+            wrapped->nb[3] = wrapped->nb[3] / dims->ne[0] * splits.split[j];
         }
     };
 
     auto create_column_split_tensors = [&]() {
         create_column_split_tensors_for(tensor, extra);
     };
 
-    auto create_dim2_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
+    auto create_dim2_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
+        dims = dims ? dims : tensor;
         extra->split_tensors = GGML_TP_SPLIT_DIM2;
-        auto splits = get_dim_splits(tensor->ne[2]);
+        auto splits = get_dim_splits(dims->ne[2]);
         for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
             auto dev = ggml_parallel_devices[j];
-            auto wrapped = ggml_backend_tp_clone_tensor(tensor);
+            auto wrapped = prepare_wrapped(tensor, dims);
             extra->tensors[j] = wrapped;
 
             // update dim2 count
             wrapped->ne[2] = splits.split[j];
             // adjust the stride for the new dim2 count
-            wrapped->nb[3] = wrapped->nb[3] / tensor->ne[2] * splits.split[j];
+            wrapped->nb[3] = wrapped->nb[3] / dims->ne[2] * splits.split[j];
        }
     };
 
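The stride rescaling in these creators relies on contiguity: every stride outside the split dimension is proportional to that dimension's extent, so dividing by the full extent and multiplying by the per-device share yields the shard's stride. Dividing by dims->ne rather than tensor->ne is the point of the new parameter: when a view inherits its split from a view_src (see the next hunk), the clone is made from dims, so the rescaling must use dims' extents. A runnable arithmetic check with assumed sizes (two devices, even row split):

#include <cassert>
#include <cstddef>
#include <cstdint>

int main() {
    // contiguous F32 tensor: ne = {4096, 1024, 8, 1}
    int64_t ne[4] = {4096, 1024, 8, 1};
    size_t  nb[4] = {4, 4 * 4096, 4 * 4096 * 1024, (size_t) 4 * 4096 * 1024 * 8};
    int64_t split = 512;  // this device's share of ne[1]

    // mirror of the diff's row-split rescaling
    size_t nb2 = nb[2] / ne[1] * split;
    size_t nb3 = nb[3] / ne[1] * split;

    assert(nb2 == nb[2] / 2);  // half a plane per device
    assert(nb3 == nb[3] / 2);  // half the batch stride per device
    // nb[0] and nb[1] are untouched: rows stay contiguous within each shard
    return 0;
}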
@@ -1202,21 +1217,36 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
         }
     };
 
-    bool force_rejoin = true;
-    switch (tensor->op) {
-        case GGML_OP_ROPE:
-        case GGML_OP_ADD:
-        case GGML_OP_VIEW:
-        case GGML_OP_FLASH_ATTN_EXT:
-        case GGML_OP_RESHAPE:
-        // case GGML_OP_PERMUTE:
-        case GGML_OP_MUL:
-        case GGML_OP_MUL_MAT:
-            force_rejoin = false;
-            break;
-    }
+    auto ensure_init_from_viewsrc = [create_default_tensors_for, create_column_split_tensors_for, create_row_split_tensors_for, create_dim2_split_tensors_for](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
+        if (extra->split_tensors != GGML_TP_SPLIT_VIEW) {
+            return;
+        }
+        auto view_src = tensor->view_src;
+        if (!view_src) {
+            return;
+        }
+        auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
+        if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
+            create_column_split_tensors_for(tensor, extra, view_src);
+        }
+        else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
+            create_row_split_tensors_for(tensor, extra, view_src);
+        }
+        else if (view_src_extra->split_tensors == GGML_TP_SPLIT_DIM2) {
+            create_dim2_split_tensors_for(tensor, extra, view_src);
+        }
+        else if (view_src_extra->split_tensors == GGML_TP_SPLIT_NONE) {
+            create_default_tensors_for(tensor, extra);
+        }
+        else {
+            GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, view_src is split as %d.\n", tensor->name, ggml_op_name(tensor->op), view_src_extra->split_tensors);
+        }
+
+        ggml_backend_tp_finish_init_tensor(tensor);
+    };
 
-    if (false) {
+    bool force_rejoin = true;
+    if (force_rejoin) {
         for (int i = 0; i < GGML_MAX_SRC; i++) {
             auto src = tensor->src[i];
             if (!src) {
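
Two things change in this hunk. First, ensure_init_from_viewsrc materializes tensors that were deferred as GGML_TP_SPLIT_VIEW: such a view gets its per-device tensors only once the split layout of its view_src is known, reusing the *_split_tensors_for creators with dims = view_src so the shard geometry comes from the viewed tensor while the op and op_params stay the view's own (that is what prepare_wrapped preserves). Second, the previously dead if (false) branch becomes live: force_rejoin is now unconditionally true, replacing the old per-op whitelist (ROPE, ADD, VIEW, FLASH_ATTN_EXT, RESHAPE, MUL, MUL_MAT), so split sources are gathered before every node computes. Conceptually the rejoin is an all-gather of the per-device shards; a schematic sketch, where copy_shard_from_device() is a hypothetical stand-in for the backend's real transfer path:

#include "ggml.h"

// hypothetical transfer helper, standing in for the backend's actual
// device-to-device / device-to-host copy path
extern void copy_shard_from_device(void * dst, const ggml_tensor * shard, size_t bytes);

// schematic only: rebuild a full tensor from per-device row shards laid out
// back to back; real rejoin logic must honor strides and split offsets
static void rejoin_rows(ggml_tensor * full, ggml_tensor * const * shards, size_t n_dev) {
    size_t offset = 0;
    for (size_t j = 0; j < n_dev; j++) {
        const size_t bytes = ggml_nbytes(shards[j]);  // shard j's payload size
        copy_shard_from_device((char *) full->data + offset, shards[j], bytes);
        offset += bytes;
    }
}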
@@ -1582,7 +1612,8 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
             break;
         }
 
-        case GGML_OP_MUL_MAT: {
+        case GGML_OP_MUL_MAT:
+        case GGML_OP_MUL_MAT_ID: {
             no_split_view(src0, src0_extra);
             if (tensor->view_src) {
                 GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
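
GGML_OP_MUL_MAT_ID now falls through into the MUL_MAT initialization, so the expert stack receives the same weight-split treatment as a dense weight matrix, and the same no-view restriction applies. The shape bookkeeping carries over directly; a hedged walkthrough with assumed MoE sizes (recall that ggml_mul_mat(a, b) with a = [k, n] and b = [k, m] yields [n, m], so splitting a's rows splits the output's features):

// as : [n_embd, n_ff, n_expert]   row-split on ne[1] -> per device
//                                 [n_embd, n_ff / n_dev, n_expert]
// x  : [n_embd, ...]              replicated on every device
// y  : [n_ff, ...]                each device computes a disjoint
//                                 n_ff / n_dev slice of the output features,
//                                 i.e. the result is column-split exactly as
//                                 in the dense GGML_OP_MUL_MAT case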
@@ -1683,76 +1714,17 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
         }
 
         case GGML_OP_CPY: {
-            auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
-            auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
-
-            if (src1_extra->split_tensors) {
-                GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src1 is split.\n", tensor->name, ggml_op_name(tensor->op));
-            }
-
-            if (!src0_split_tensors) {
-                create_default_tensors();
-                set_src_tensor(0, GGML_TP_SPLIT_NONE);
-                set_src_tensor(1, GGML_TP_SPLIT_NONE);
-            }
-            else {
-                // GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
-                auto view_src = src0;
-                while (view_src->view_src) {
-                    view_src = view_src->view_src;
-                }
-                auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
-
-                if (ggml_are_same_shape(tensor, view_src)) {
-                    if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
-                        create_column_split_tensors_for(src0, src0_extra);
-                        ggml_backend_tp_finish_init_tensor(src0);
-                        ensure_column_split(src1);
-                        create_column_split_tensors();
-                        set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
-                        set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
-                    }
-                    else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
-                        create_row_split_tensors_for(src0, src0_extra);
-                        ggml_backend_tp_finish_init_tensor(src0);
-                        ensure_row_split(src1);
-                        create_row_split_tensors();
-                        set_src_tensor(0, GGML_TP_SPLIT_ROWS);
-                        set_src_tensor(1, GGML_TP_SPLIT_ROWS);
-                    }
-                }
-                else {
-                    if (src0_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
-                        if (tensor->ne[0] == view_src->ne[0]) {
-                            if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
-                                create_column_split_tensors_for(src0, src0_extra);
-                                ggml_backend_tp_finish_init_tensor(src0);
-                            }
-                            else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
-                                create_row_split_tensors_for(src0, src0_extra);
-                                ggml_backend_tp_finish_init_tensor(src0);
-                            }
-                            else {
-                                GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), src0_extra->split_tensors, GGML_TP_SPLIT_NONE);
-                            }
-                        }
-                        else if (tensor->ne[0] > view_src->ne[0]) {
-                            create_column_split_tensors_for(src0, src0_extra);
-                            ggml_backend_tp_finish_init_tensor(src0);
-                        }
-                        else {
-                            create_row_split_tensors_for(src0, src0_extra);
-                            ggml_backend_tp_finish_init_tensor(src0);
-                        }
-                    }
-
-                    ensure_rejoined(tensor, src0);
-
-                    create_default_tensors();
-                    set_src_tensor(0, GGML_TP_SPLIT_NONE);
-                    set_src_tensor(1, GGML_TP_SPLIT_NONE);
-                }
-            }
+            // src1 is the destination, and has already been created; its op
+            // may be NONE or VIEW. It would be possible to use this cpy op to
+            // make the src1 tensor tree split, but without graph introspection
+            // this is simpler for now.
+            ensure_init_from_viewsrc(src0, src0_extra);
+            ensure_init_from_viewsrc(src1, src1_extra);
+            ensure_rejoined(tensor, src0);
+            ensure_rejoined(tensor, src1);
+            create_default_tensors();
+            set_src_tensor(0, GGML_TP_SPLIT_NONE);
+            set_src_tensor(1, GGML_TP_SPLIT_NONE);
 
             break;
         }
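
The rewritten CPY case trades the old shape-matching special cases for one uniform sequence: finish any deferred views on both operands, gather them back to full tensors, then run the copy unsplit. This can introduce a gather the old code avoided when source and destination shared a split layout; the in-code comment flags that as a deliberate simplification. A schematic contrast (condensed from the two versions of this hunk, not new code):

// before: when src0 viewed a column-split tensor and shapes matched, the copy
//         stayed split end to end:
//             create_column_split_tensors_for(src0, src0_extra);
//             ensure_column_split(src1);
//             create_column_split_tensors();   // shard-by-shard copy, no gather
// after:  every CPY normalizes to full tensors first:
//             ensure_init_from_viewsrc(src0, src0_extra);  // finish deferred views
//             ensure_rejoined(tensor, src0);               // all-gather both operands
//             create_default_tensors();                    // replicated copy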
@@ -2554,7 +2526,7 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
         }
     }
 
-    if (op->op != GGML_OP_MUL_MAT) {
+    if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID) {
         for (int i = 0; i < GGML_MAX_SRC; i++) {
            auto src = op->src[i];
            if (!src) {