@@ -962,6 +962,18 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
962962
963963 extra->initialized = true ;
964964
965+ // Ensure every src is initialized first; view tensors can be used out of order.
966+ for (size_t i = 0 ; i < GGML_MAX_SRC; i++) {
967+ auto src = tensor->src [i];
968+ if (!src) {
969+ break ;
970+ }
971+ auto src_extra = (ggml_tensor_parallel_extra *)src->extra ;
972+ // node_index belongs to the tensor being initialized, not to src; it is forwarded only so debugging has an index.
973+ do_init (node_index, src, src_extra);
974+ }
975+
976+
965977 auto src0 = tensor->src [0 ];
966978 auto src1 = tensor->src [1 ];
967979 auto src2 = tensor->src [2 ];
@@ -1608,7 +1620,6 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
16081620
16091621 GGML_ASSERT (ggml_are_same_shape (extra->tensors [0 ], extra->tensors [0 ]->src [0 ]) && " Tensor parallel tensors must have the same shape." );
16101622 GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [0 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne1)." );
1611- GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [1 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne1)." );
16121623 break ;
16131624 }
16141625
@@ -1725,7 +1736,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17251736 create_default_tensors ();
17261737 set_src_tensor (0 , GGML_TP_SPLIT_NONE);
17271738 set_src_tensor (1 , GGML_TP_SPLIT_NONE);
1728-
1739+ check_srcs ();
17291740 break ;
17301741 }
17311742
@@ -1762,6 +1773,31 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17621773 break ;
17631774 }
17641775
1776+ case GGML_OP_SUM_ROWS: {
1777+ no_split_view (src0, src0_extra);
1778+ if (extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1779+ create_column_split_tensors ();
1780+ set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
1781+ }
1782+ else {
1783+ ensure_rejoined (tensor, src0);
1784+ create_default_tensors ();
1785+ set_src_tensor (0 , GGML_TP_SPLIT_NONE);
1786+ }
1787+ check_srcs ();
1788+ break ;
1789+ }
1790+ case GGML_OP_SOFT_MAX:
1791+ case GGML_OP_ARGSORT: {
1792+ no_split_view (src0, src0_extra);
1793+
1794+ ensure_rejoined (tensor, src0);
1795+ create_default_tensors ();
1796+ set_src_tensor (0 , GGML_TP_SPLIT_NONE);
1797+ check_srcs ();
1798+ break ;
1799+ }
1800+
17651801 case GGML_OP_VIEW:
17661802 case GGML_OP_PERMUTE:
17671803 case GGML_OP_RESHAPE: {
@@ -2513,6 +2549,10 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
25132549 GGML_UNUSED (dev);
25142550 GGML_UNUSED (op);
25152551
2552+ if (op->op == GGML_OP_MUL_MAT_ID) {
2553+ return false ;
2554+ }
2555+
25162556 auto buft = op->buffer ? op->buffer ->buft : nullptr ;
25172557 if (buft && (!ggml_backend_buft_is_tp_split (buft) && !ggml_backend_buft_is_tp (buft))) {
25182558 return false ;
0 commit comments