@@ -1697,105 +1697,63 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
16971697 }
16981698 else {
16991699 // GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
1700+ auto view_src = src0;
1701+ while (view_src->view_src ) {
1702+ view_src = view_src->view_src ;
1703+ }
1704+ auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra ;
17001705
1701- if (src0_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
1702- auto view_src = src0;
1703- while (view_src->view_src ) {
1704- view_src = view_src->view_src ;
1706+ if (ggml_are_same_shape (tensor, view_src)) {
1707+ if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1708+ create_column_split_tensors_for (src0, src0_extra);
1709+ ggml_backend_tp_finish_init_tensor (src0);
1710+ ensure_column_split (src1);
1711+ create_column_split_tensors ();
1712+ set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
1713+ set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
1714+ }
1715+ else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1716+ create_row_split_tensors_for (src0, src0_extra);
1717+ ggml_backend_tp_finish_init_tensor (src0);
1718+ ensure_row_split (src1);
1719+ create_row_split_tensors ();
1720+ set_src_tensor (0 , GGML_TP_SPLIT_ROWS);
1721+ set_src_tensor (1 , GGML_TP_SPLIT_ROWS);
17051722 }
1706- auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra ;
1707- if (tensor->ne [0 ] == view_src->ne [0 ]) {
1708- if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1723+ }
1724+ else {
1725+ if (src0_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
1726+ if (tensor->ne [0 ] == view_src->ne [0 ]) {
1727+ if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1728+ create_column_split_tensors_for (src0, src0_extra);
1729+ ggml_backend_tp_finish_init_tensor (src0);
1730+ }
1731+ else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1732+ create_row_split_tensors_for (src0, src0_extra);
1733+ ggml_backend_tp_finish_init_tensor (src0);
1734+ }
1735+ else {
1736+ GGML_ABORT (" Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n " , tensor->name , ggml_op_name (tensor->op ), src0_extra->split_tensors , GGML_TP_SPLIT_NONE);
1737+ }
1738+ }
1739+ else if (tensor->ne [0 ] > view_src->ne [0 ]) {
17091740 create_column_split_tensors_for (src0, src0_extra);
17101741 ggml_backend_tp_finish_init_tensor (src0);
17111742 }
1712- else if (view_src_extra-> split_tensors == GGML_TP_SPLIT_ROWS) {
1743+ else {
17131744 create_row_split_tensors_for (src0, src0_extra);
17141745 ggml_backend_tp_finish_init_tensor (src0);
17151746 }
1716- else {
1717- GGML_ABORT (" Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n " , tensor->name , ggml_op_name (tensor->op ), src0_extra->split_tensors , GGML_TP_SPLIT_NONE);
1718- }
1719- }
1720- else if (tensor->ne [0 ] > view_src->ne [0 ]) {
1721- create_column_split_tensors_for (src0, src0_extra);
1722- ggml_backend_tp_finish_init_tensor (src0);
17231747 }
1724- else {
1725- create_row_split_tensors_for (src0, src0_extra);
1726- ggml_backend_tp_finish_init_tensor (src0);
1727- }
1728- }
1729-
1730- // create_column_split_tensors_for(src0, src0_extra);
1731- // ggml_backend_tp_finish_init_tensor(src0);
1732- ensure_rejoined (tensor, src0);
1733-
1734- create_default_tensors ();
1735- set_src_tensor (0 , GGML_TP_SPLIT_NONE);
1736- set_src_tensor (1 , GGML_TP_SPLIT_NONE);
1737-
17381748
1749+ ensure_rejoined (tensor, src0);
17391750
1740- // auto view_src = src0;
1741- // while (view_src->view_src) {
1742- // view_src = view_src->view_src;
1743- // ensure_rejoined(tensor, view_src);
1744- // }
1745- // auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
1746-
1747- // if (src0_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
1748- // if (src0->ne[0] == view_src->ne[0]) {
1749- // if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1750- // create_column_split_tensors_for(src0, src0_extra);
1751- // // set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_COLUMNS);
1752- // ggml_backend_tp_finish_init_tensor(src0);
1753- // }
1754- // else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1755- // create_row_split_tensors_for(src0, src0_extra);
1756- // // set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_ROWS);
1757- // ggml_backend_tp_finish_init_tensor(src0);
1758- // }
1759- // else {
1760- // GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), view_src_extra->split_tensors, GGML_TP_SPLIT_NONE);
1761- // }
1762- // }
1763- // else
1764- // if (src0->ne[0] > view_src->ne[0]) {
1765- // create_column_split_tensors_for(src0, src0_extra);
1766- // // set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_COLUMNS);
1767- // ggml_backend_tp_finish_init_tensor(src0);
1768- // }
1769- // else {
1770- // create_row_split_tensors_for(src0, src0_extra);
1771- // // set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_ROWS);
1772- // ggml_backend_tp_finish_init_tensor(src0);
1773- // }
1774- // }
1775-
1776-
1777- // if (tensor->ne[0] == view_src->ne[0]) {
1778- // if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1779- // create_column_split_tensors();
1780- // set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1781- // }
1782- // else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1783- // create_row_split_tensors();
1784- // set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1785- // }
1786- // else {
1787- // GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), view_src_extra->split_tensors, GGML_TP_SPLIT_NONE);
1788- // }
1789- // }
1790- // else if (tensor->ne[0] > view_src->ne[0]) {
1791- // create_column_split_tensors();
1792- // set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1793- // }
1794- // else {
1795- // create_row_split_tensors();
1796- // set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1797- // }
1751+ create_default_tensors ();
1752+ set_src_tensor (0 , GGML_TP_SPLIT_NONE);
1753+ set_src_tensor (1 , GGML_TP_SPLIT_NONE);
1754+ }
17981755 }
1756+
17991757 break ;
18001758 }
18011759
0 commit comments