Skip to content

Commit 2e3ce66

Browse files
committed
fix fa rejoin
1 parent fce09c6 commit 2e3ce66

File tree

1 file changed

+49
-19
lines changed

1 file changed

+49
-19
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
15861586
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
15871587
// unless this is an add op, a tensor in a reduced state
15881588
// does not count as a split tensor. it will require a rejoin.
1589-
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE && tensor->op != GGML_OP_ADD) {
1589+
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE && tensor->op != GGML_OP_ADD && tensor->op != GGML_OP_GET_ROWS) {
15901590
ensure_rejoined(tensor, src);
15911591
continue;
15921592
}
@@ -1674,6 +1674,13 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
16741674
}
16751675
}
16761676
}
1677+
else if (tensor->op == GGML_OP_GET_ROWS) {
1678+
auto src0 = tensor->src[0];
1679+
auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra;
1680+
if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src0_extra->has_rejoin) {
1681+
extra->split_tensors = GGML_TP_SPLIT_REDUCE;
1682+
}
1683+
}
16771684

16781685
if (ctx->split) {
16791686
extra->split_tensors = GGML_TP_SPLIT_ROWS;
@@ -1767,37 +1774,60 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
17671774
}
17681775
}
17691776
else {
1770-
auto original_ne0 = wrapped->ne[0];
17711777
if (tensor->op == GGML_OP_RESHAPE) {
17721778
// ehhhhh i dunno man.
17731779
// 8192x2x1x1 -> 128x64x2x1
17741780
// 8192 is column partitioned into two 4096x2x1x1
17751781
// 4096 / 128 = 32 so actual reshape result is 128x32x2x1
1776-
// auto splits = get_col_splits(tensor);
17771782

17781783
auto src = tensor->src[0];
17791784
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
1780-
auto src_cols = src_extra->tensors[j]->ne[0];
1781-
1782-
if (src_cols > wrapped->ne[0]) {
1783-
auto original_ne1 = wrapped->ne[1];
1784-
wrapped->ne[1] = src_cols / wrapped->ne[0];
1785-
if (wrapped->ne[1] >= original_ne1) {
1786-
wrapped->ne[2] = wrapped->ne[1] / original_ne1;
1787-
wrapped->ne[1] = original_ne1;
1785+
1786+
if (src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1787+
auto src_cols = src_extra->tensors[j]->ne[0];
1788+
if (src_cols > wrapped->ne[0]) {
1789+
auto original_ne1 = wrapped->ne[1];
1790+
wrapped->ne[1] = src_cols / wrapped->ne[0];
1791+
if (wrapped->ne[1] >= original_ne1) {
1792+
wrapped->ne[2] = wrapped->ne[1] / original_ne1;
1793+
wrapped->ne[1] = original_ne1;
1794+
}
1795+
else {
1796+
wrapped->ne[2] = tensor->ne[2];
1797+
}
1798+
wrapped->ne[3] = tensor->ne[3];
1799+
if (tensor->ne[3] > 1) {
1800+
int i = 0;
1801+
}
17881802
}
17891803
else {
1790-
wrapped->ne[2] = tensor->ne[2];
1804+
GGML_ABORT("ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with split columns but src cols %zu < wrapped ne[0] %zu\n", tensor->name, src->name, src_cols, wrapped->ne[0]);
1805+
// auto original_ne0 = wrapped->ne[0];
1806+
// wrapped->ne[0] = src_cols;
1807+
// wrapped->ne[2] = wrapped->ne[2] * original_ne0 / wrapped->ne[0];
17911808
}
1792-
wrapped->ne[3] = tensor->ne[3];
1793-
if (tensor->ne[3] > 1) {
1794-
int i = 0;
1809+
}
1810+
else if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1811+
// 128x64x2x1 -> 8192x2x1x1
1812+
// 64 is row partitioned into two 128x32x2x1
1813+
// 128 * 32 = 4096 (the 2 dim is ignored since that would not be contiguous in memory row-wise)
1814+
// so actual reshape result is 4096x2x1x1
1815+
// interestingly the split type changes to column here.
1816+
1817+
auto src_rows_length = src_extra->tensors[j]->ne[0] * src_extra->tensors[j]->ne[1];
1818+
if (src_rows_length > tensor->ne[0]) {
1819+
// could implement it but no need yet.
1820+
GGML_ABORT("ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with split rows but src rows length %zu > wrapped ne[0] %zu\n", tensor->name, src->name, src_rows_length, wrapped->ne[0]);
1821+
}
1822+
else {
1823+
wrapped->ne[0] = src_rows_length;
1824+
wrapped->ne[1] = src_extra->tensors[j]->ne[2];
1825+
1826+
extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
17951827
}
17961828
}
17971829
else {
1798-
auto original_ne0 = wrapped->ne[0];
1799-
wrapped->ne[0] = src_cols;
1800-
wrapped->ne[2] = wrapped->ne[2] * original_ne0 / wrapped->ne[0];
1830+
GGML_ABORT("ggml_backend_tp_buffer_init_tensor: tensor %s has src %s with unknown split type %d\n", tensor->name, src->name, src_extra->split_tensors);
18011831
}
18021832

18031833
wrapped->nb[1] = wrapped->nb[0] * wrapped->ne[0];
@@ -1816,7 +1846,6 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
18161846
wrapped->nb[1] = src_extra->tensors[j]->nb[1];
18171847
wrapped->nb[2] = src_extra->tensors[j]->nb[2];
18181848
wrapped->nb[3] = src_extra->tensors[j]->nb[3];
1819-
ensure_rejoined(nullptr, tensor);
18201849
}
18211850
else if (tensor->op == GGML_OP_PERMUTE) {
18221851
auto src = tensor->src[0];
@@ -1851,6 +1880,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
18511880
}
18521881
else {
18531882
if (extra->split_tensors != GGML_TP_SPLIT_REDUCE) {
1883+
auto original_ne0 = wrapped->ne[0];
18541884
ggml_split splits = get_col_splits(vs);
18551885

18561886
// update col count

0 commit comments

Comments
 (0)