@@ -1586,7 +1586,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
15861586 auto src_extra = (ggml_tensor_parallel_extra *)src->extra ;
15871587 // unless this is an add op, a tensor in a reduced state
15881588 // does not count as a split tensor. it will require a rejoin.
1589- if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE && tensor->op != GGML_OP_ADD) {
1589+ if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE && tensor->op != GGML_OP_ADD && tensor-> op != GGML_OP_GET_ROWS ) {
15901590 ensure_rejoined (tensor, src);
15911591 continue ;
15921592 }
@@ -1674,6 +1674,13 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
16741674 }
16751675 }
16761676 }
1677+ else if (tensor->op == GGML_OP_GET_ROWS) {
1678+ auto src0 = tensor->src [0 ];
1679+ auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra ;
1680+ if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src0_extra->has_rejoin ) {
1681+ extra->split_tensors = GGML_TP_SPLIT_REDUCE;
1682+ }
1683+ }
16771684
16781685 if (ctx->split ) {
16791686 extra->split_tensors = GGML_TP_SPLIT_ROWS;
@@ -1767,37 +1774,60 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
17671774 }
17681775 }
17691776 else {
1770- auto original_ne0 = wrapped->ne [0 ];
17711777 if (tensor->op == GGML_OP_RESHAPE) {
17721778 // ehhhhh i dunno man.
17731779 // 8192x2x1x1 -> 128x64x2x1
17741780 // 8192 is column partitioned into two 4096x2x1x1
17751781 // 4096 / 128 = 32 so actual reshape result is 128x32x2x1
1776- // auto splits = get_col_splits(tensor);
17771782
17781783 auto src = tensor->src [0 ];
17791784 auto src_extra = (ggml_tensor_parallel_extra *)src->extra ;
1780- auto src_cols = src_extra->tensors [j]->ne [0 ];
1781-
1782- if (src_cols > wrapped->ne [0 ]) {
1783- auto original_ne1 = wrapped->ne [1 ];
1784- wrapped->ne [1 ] = src_cols / wrapped->ne [0 ];
1785- if (wrapped->ne [1 ] >= original_ne1) {
1786- wrapped->ne [2 ] = wrapped->ne [1 ] / original_ne1;
1787- wrapped->ne [1 ] = original_ne1;
1785+
1786+ if (src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1787+ auto src_cols = src_extra->tensors [j]->ne [0 ];
1788+ if (src_cols > wrapped->ne [0 ]) {
1789+ auto original_ne1 = wrapped->ne [1 ];
1790+ wrapped->ne [1 ] = src_cols / wrapped->ne [0 ];
1791+ if (wrapped->ne [1 ] >= original_ne1) {
1792+ wrapped->ne [2 ] = wrapped->ne [1 ] / original_ne1;
1793+ wrapped->ne [1 ] = original_ne1;
1794+ }
1795+ else {
1796+ wrapped->ne [2 ] = tensor->ne [2 ];
1797+ }
1798+ wrapped->ne [3 ] = tensor->ne [3 ];
1799+ if (tensor->ne [3 ] > 1 ) {
1800+ int i = 0 ;
1801+ }
17881802 }
17891803 else {
1790- wrapped->ne [2 ] = tensor->ne [2 ];
1804+ GGML_ABORT (" ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with split columns but src cols %zu < wrapped ne[0] %zu\n " , tensor->name , src->name , src_cols, wrapped->ne [0 ]);
1805+ // auto original_ne0 = wrapped->ne[0];
1806+ // wrapped->ne[0] = src_cols;
1807+ // wrapped->ne[2] = wrapped->ne[2] * original_ne0 / wrapped->ne[0];
17911808 }
1792- wrapped->ne [3 ] = tensor->ne [3 ];
1793- if (tensor->ne [3 ] > 1 ) {
1794- int i = 0 ;
1809+ }
1810+ else if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1811+ // 128x64x2x1 -> 8192x2x1x1
1812+ // 64 is row partitioned into two 128x32x2x1
1813+ // 128 * 32 = 4096 (the 2 dim is ignores since that would not be contiguous in memory row-wise)
1814+ // so actual reshape result is 4096x2x1x1
1815+ // interestingly the split type changes to column here.
1816+
1817+ auto src_rows_length = src_extra->tensors [j]->ne [0 ] * src_extra->tensors [j]->ne [1 ];
1818+ if (src_rows_length > tensor->ne [0 ]) {
1819+ // could implement it but no need yet.
1820+ GGML_ABORT (" ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with split rows but src rows length %zu > wrapped ne[0] %zu\n " , tensor->name , src->name , src_rows_length, wrapped->ne [0 ]);
1821+ }
1822+ else {
1823+ wrapped->ne [0 ] = src_rows_length;
1824+ wrapped->ne [1 ] = src_extra->tensors [j]->ne [2 ];
1825+
1826+ extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
17951827 }
17961828 }
17971829 else {
1798- auto original_ne0 = wrapped->ne [0 ];
1799- wrapped->ne [0 ] = src_cols;
1800- wrapped->ne [2 ] = wrapped->ne [2 ] * original_ne0 / wrapped->ne [0 ];
1830+ GGML_ABORT (" ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with unknown split type %d\n " , tensor->name , src->name , src_extra->split_tensors );
18011831 }
18021832
18031833 wrapped->nb [1 ] = wrapped->nb [0 ] * wrapped->ne [0 ];
@@ -1816,7 +1846,6 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
18161846 wrapped->nb [1 ] = src_extra->tensors [j]->nb [1 ];
18171847 wrapped->nb [2 ] = src_extra->tensors [j]->nb [2 ];
18181848 wrapped->nb [3 ] = src_extra->tensors [j]->nb [3 ];
1819- ensure_rejoined (nullptr , tensor);
18201849 }
18211850 else if (tensor->op == GGML_OP_PERMUTE) {
18221851 auto src = tensor->src [0 ];
@@ -1851,6 +1880,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
18511880 }
18521881 else {
18531882 if (extra->split_tensors != GGML_TP_SPLIT_REDUCE) {
1883+ auto original_ne0 = wrapped->ne [0 ];
18541884 ggml_split splits = get_col_splits (vs);
18551885
18561886 // update col count
0 commit comments