@@ -1294,6 +1294,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12941294
12951295 bool force_rejoin = true ;
12961296 switch (tensor->op ) {
1297+ case GGML_OP_MUL:
12971298 case GGML_OP_MUL_MAT:
12981299 force_rejoin = false ;
12991300 break ;
@@ -1443,8 +1444,8 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
14431444 check_srcs ();
14441445 break ;
14451446
1446- case GGML_OP_ADD :
1447- case GGML_OP_SUB : {
1447+ case GGML_OP_SUB :
1448+ case GGML_OP_ADD : {
14481449 no_split_view (src0, src0_extra);
14491450 no_split_view (src1, src1_extra);
14501451 if (tensor->view_src ) {
@@ -1454,16 +1455,17 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
14541455 auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors ;
14551456 auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors ;
14561457
1457- // sometimes src0/src1 may be used as input twice, one of which needs a rejoin.
1458- // so check the native split state rather than the rejoin state here.
1459- if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1458+ if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
14601459 create_reduce_tensors ();
14611460 create_reduce_op_tensors ();
14621461 }
14631462 else if (!src0_split_tensors && !src1_split_tensors) {
1463+ ensure_rejoined (tensor, src0);
1464+ ensure_rejoined (tensor, src1);
14641465 create_default_tensors ();
14651466 set_src_tensor (0 , GGML_TP_SPLIT_NONE);
14661467 set_src_tensor (1 , GGML_TP_SPLIT_NONE);
1468+ GGML_ASSERT (ggml_are_same_shape (extra->tensors [0 ], extra->tensors [0 ]->src [0 ]) && " Tensor parallel tensors must have the same shape." );
14671469 }
14681470 else if ((src0_split_tensors ^ src1_split_tensors) || (src0_split_tensors == src1_split_tensors)) {
14691471 auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
@@ -1504,6 +1506,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15041506 set_src_tensor (0 , GGML_TP_SPLIT_NONE);
15051507 set_src_tensor (1 , GGML_TP_SPLIT_NONE);
15061508 }
1509+
1510+ GGML_ASSERT (ggml_are_same_shape (extra->tensors [0 ], extra->tensors [0 ]->src [0 ]) && " Tensor parallel tensors must have the same shape." );
1511+ GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [0 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne0)." );
1512+ GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [1 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne0)." );
15071513 break ;
15081514 }
15091515
@@ -1531,6 +1537,8 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15311537 else {
15321538 GGML_ABORT (" Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns or rows.\n " , tensor->name , ggml_op_name (tensor->op ));
15331539 }
1540+
1541+ GGML_ASSERT (ggml_are_same_shape (extra->tensors [0 ], extra->tensors [0 ]->src [0 ]) && " Tensor parallel tensors must have the same shape." );
15341542 break ;
15351543 }
15361544
@@ -1550,16 +1558,22 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15501558 if (src0_split_tensors == src1_split_tensors) {
15511559 // splits match
15521560 if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
1561+ ensure_column_split (src0);
1562+ ensure_column_split (src1);
15531563 create_column_split_tensors ();
15541564 set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
15551565 set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
15561566 }
15571567 else if (src0_split_tensors == GGML_TP_SPLIT_ROWS) {
1568+ ensure_row_split (src0);
1569+ ensure_row_split (src1);
15581570 create_row_split_tensors ();
15591571 set_src_tensor (0 , GGML_TP_SPLIT_ROWS);
15601572 set_src_tensor (1 , GGML_TP_SPLIT_ROWS);
15611573 }
15621574 else if (src0_split_tensors == GGML_TP_SPLIT_NONE) {
1575+ ensure_rejoined (tensor, src0);
1576+ ensure_rejoined (tensor, src1);
15631577 create_default_tensors ();
15641578 set_src_tensor (0 , GGML_TP_SPLIT_NONE);
15651579 set_src_tensor (1 , GGML_TP_SPLIT_NONE);
@@ -1580,11 +1594,15 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15801594 // one split, one not split
15811595 auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
15821596 if (split_tensors == GGML_TP_SPLIT_COLUMNS) {
15831600 ensure_column_split (src0);
15841601 ensure_column_split (src1);
15851602 create_column_split_tensors ();
15861603 set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
15871604 set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
1605+ // ensure_rejoined(nullptr, tensor);
15881606 }
15891607 else if (split_tensors == GGML_TP_SPLIT_ROWS) {
15901608 ensure_row_split (src0);
@@ -1597,6 +1615,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15971615 GGML_ABORT (" Tensor %s has unsupported op %s for tensor parallelism, src0 is split but src1 is not.\n " , tensor->name , ggml_op_name (tensor->op ));
15981616 }
15991617 }
1618+
1619+ GGML_ASSERT (ggml_are_same_shape (extra->tensors [0 ], extra->tensors [0 ]->src [0 ]) && " Tensor parallel tensors must have the same shape." );
1620+ GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [0 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne0)." );
1621+ GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [1 ]->ne [0 ] && " Tensor parallel has incorrect broadcast dimension (ne0)." );
16001622 break ;
16011623 }
16021624
@@ -1648,12 +1670,25 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
16481670 set_src_tensor (1 , GGML_TP_SPLIT_NONE);
16491671 }
16501672 else if (src0_ctx->split && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
1651- // a weight matrix is multiplied by a column split tensor (prior to ROPE), it can be massaged to a column split.
1652- // this results in a reduce split.
1653- ensure_weight_column_split (src0);
1654- create_reduce_tensors ();
1655- set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
1656- set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
1673+ // a weight matrix is multiplied by a column split tensor (prior to ROPE), it can be massaged to a column split.
1674+ // this results in a reduce split.
1675+ ensure_weight_column_split (src0);
1676+ create_reduce_tensors ();
1677+ set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
1678+ set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
16571692 }
16581693 else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
16591694 // technically supported like the weights above, but not expected.
@@ -1680,6 +1715,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
16801715 else {
16811716 GGML_ABORT (" Tensor %s has unsupported op %s for tensor parallelism.\n " , tensor->name , ggml_op_name (tensor->op ));
16821717 }
1718+
1719+ GGML_ASSERT (extra->tensors [0 ]->src [0 ]->ne [0 ] == extra->tensors [0 ]->src [1 ]->ne [0 ] && " Tensor parallel tensors must have the same inner dimension (ne0)." );
1720+ GGML_ASSERT (extra->tensors [0 ]->ne [0 ] == extra->tensors [0 ]->src [0 ]->ne [1 ] && " Tensor parallel has incorrect outer dimension (ne0)." );
1721+ GGML_ASSERT (extra->tensors [0 ]->ne [1 ] == extra->tensors [0 ]->src [1 ]->ne [1 ] && " Tensor parallel has incorrect outer dimension (ne1)." );
16831722 break ;
16841723 }
16851724
0 commit comments