Skip to content

Commit 412fe47

Browse files
committed
wip cpy
1 parent 700c742 commit 412fe47

File tree

1 file changed

+75
-103
lines changed

1 file changed

+75
-103
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 75 additions & 103 deletions
Original file line number | Diff line number | Diff line change
@@ -201,7 +201,7 @@ static bool ggml_backend_tp_is_split(ggml_tensor * tensor) {
201201

202202
static bool is_split_compatible(ggml_tensor * tensor) {
203203
auto op = tensor->op;
204-
if (op == GGML_OP_MUL_MAT) {
204+
if (op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID) {
205205
auto src1 = tensor->src[1];
206206
if (src1->buffer && ggml_backend_buft_is_tp_split(src1->buffer->buft)) {
207207
return false;
@@ -213,6 +213,7 @@ static bool is_split_compatible(ggml_tensor * tensor) {
213213
switch (op) {
214214
case GGML_OP_UNARY:
215215
case GGML_OP_MUL_MAT:
216+
case GGML_OP_MUL_MAT_ID:
216217
case GGML_OP_ADD:
217218
case GGML_OP_SUB:
218219
case GGML_OP_MUL:
@@ -866,7 +867,7 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
866867
// if (device_index == 0) {
867868
// for (int i = 0; i < backend_graph->n_nodes; i++) {
868869
// auto tensor = backend_graph->nodes[i];
869-
// printf("TP %d: %s %s %x\n", node_index, ggml_op_name(tensor->op), tensor->name, tensor->data);
870+
// printf("TP %d: %s %s %x\n", node_index - backend_graph->n_nodes + i, ggml_op_name(tensor->op), tensor->name, tensor->data);
870871
// for (int j = 0; j < GGML_MAX_SRC; j++) {
871872
// auto src = tensor->src[j];
872873
// if (!src) {
@@ -992,59 +993,73 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
992993
}
993994
};
994995

995-
auto create_row_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
996+
auto prepare_wrapped = [](ggml_tensor * tensor, ggml_tensor * dims) {
997+
auto wrapped = ggml_backend_tp_clone_tensor(dims);
998+
if (dims != tensor) {
999+
wrapped->op = tensor->op;
1000+
for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
1001+
wrapped->op_params[i] = tensor->op_params[i];
1002+
}
1003+
}
1004+
return wrapped;
1005+
};
1006+
1007+
auto create_row_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
1008+
dims = dims ? dims : tensor;
9961009
extra->split_tensors = GGML_TP_SPLIT_ROWS;
997-
auto splits = get_row_splits(tensor);
1010+
auto splits = get_row_splits(dims);
9981011
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
9991012
auto dev = ggml_parallel_devices[j];
1000-
auto wrapped = ggml_backend_tp_clone_tensor(tensor);
1013+
auto wrapped = prepare_wrapped(tensor, dims);
10011014
extra->tensors[j] = wrapped;
10021015

10031016
// update row count
10041017
wrapped->ne[1] = splits.split[j];
10051018
// adjust the stride for the new row count
1006-
wrapped->nb[2] = wrapped->nb[2] / tensor->ne[1] * splits.split[j];
1007-
wrapped->nb[3] = wrapped->nb[3] / tensor->ne[1] * splits.split[j];
1019+
wrapped->nb[2] = wrapped->nb[2] / dims->ne[1] * splits.split[j];
1020+
wrapped->nb[3] = wrapped->nb[3] / dims->ne[1] * splits.split[j];
10081021
}
10091022
};
10101023

10111024
auto create_row_split_tensors = [&]() {
10121025
create_row_split_tensors_for(tensor, extra);
10131026
};
10141027

1015-
auto create_column_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
1028+
auto create_column_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
1029+
dims = dims ? dims : tensor;
10161030
extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
1017-
auto splits = get_col_splits(tensor);
1031+
auto splits = get_col_splits(dims);
10181032
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
10191033
auto dev = ggml_parallel_devices[j];
1020-
auto wrapped = ggml_backend_tp_clone_tensor(tensor);
1034+
auto wrapped = prepare_wrapped(tensor, dims);
10211035
extra->tensors[j] = wrapped;
10221036

10231037
// update col count
10241038
wrapped->ne[0] = splits.split[j];
10251039
// adjust the stride for the new col count
1026-
wrapped->nb[1] = wrapped->nb[1] / tensor->ne[0] * splits.split[j];
1027-
wrapped->nb[2] = wrapped->nb[2] / tensor->ne[0] * splits.split[j];
1028-
wrapped->nb[3] = wrapped->nb[3] / tensor->ne[0] * splits.split[j];
1040+
wrapped->nb[1] = wrapped->nb[1] / dims->ne[0] * splits.split[j];
1041+
wrapped->nb[2] = wrapped->nb[2] / dims->ne[0] * splits.split[j];
1042+
wrapped->nb[3] = wrapped->nb[3] / dims->ne[0] * splits.split[j];
10291043
}
10301044
};
10311045

10321046
auto create_column_split_tensors = [&]() {
10331047
create_column_split_tensors_for(tensor, extra);
10341048
};
10351049

1036-
auto create_dim2_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
1050+
auto create_dim2_split_tensors_for = [prepare_wrapped](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra, ggml_tensor * dims = nullptr) {
1051+
dims = dims ? dims : tensor;
10371052
extra->split_tensors = GGML_TP_SPLIT_DIM2;
1038-
auto splits = get_dim_splits(tensor->ne[2]);
1053+
auto splits = get_dim_splits(dims->ne[2]);
10391054
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
10401055
auto dev = ggml_parallel_devices[j];
1041-
auto wrapped = ggml_backend_tp_clone_tensor(tensor);
1056+
auto wrapped = prepare_wrapped(tensor, dims);
10421057
extra->tensors[j] = wrapped;
10431058

10441059
// update dim2 count
10451060
wrapped->ne[2] = splits.split[j];
10461061
// adjust the stride for the new dim2 count
1047-
wrapped->nb[3] = wrapped->nb[3] / tensor->ne[2] * splits.split[j];
1062+
wrapped->nb[3] = wrapped->nb[3] / dims->ne[2] * splits.split[j];
10481063
}
10491064
};
10501065

@@ -1202,21 +1217,36 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
12021217
}
12031218
};
12041219

1205-
bool force_rejoin = true;
1206-
switch (tensor->op) {
1207-
case GGML_OP_ROPE:
1208-
case GGML_OP_ADD:
1209-
case GGML_OP_VIEW:
1210-
case GGML_OP_FLASH_ATTN_EXT:
1211-
case GGML_OP_RESHAPE:
1212-
// case GGML_OP_PERMUTE:
1213-
case GGML_OP_MUL:
1214-
case GGML_OP_MUL_MAT:
1215-
force_rejoin = false;
1216-
break;
1217-
}
1220+
auto ensure_init_from_viewsrc = [create_default_tensors_for, create_column_split_tensors_for, create_row_split_tensors_for, create_dim2_split_tensors_for](ggml_tensor * tensor, ggml_tensor_parallel_extra *extra) {
1221+
if (extra->split_tensors != GGML_TP_SPLIT_VIEW) {
1222+
return;
1223+
}
1224+
auto view_src = tensor->view_src;
1225+
if (!view_src) {
1226+
return;
1227+
}
1228+
auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
1229+
if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1230+
create_column_split_tensors_for(tensor, extra, view_src);
1231+
}
1232+
else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1233+
create_row_split_tensors_for(tensor, extra, view_src);
1234+
}
1235+
else if (view_src_extra->split_tensors == GGML_TP_SPLIT_DIM2) {
1236+
create_dim2_split_tensors_for(tensor, extra, view_src);
1237+
}
1238+
else if (view_src_extra->split_tensors == GGML_TP_SPLIT_NONE) {
1239+
create_default_tensors_for(tensor, extra);
1240+
}
1241+
else {
1242+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, view_src is split as %d.\n", tensor->name, ggml_op_name(tensor->op), view_src_extra->split_tensors);
1243+
}
1244+
1245+
ggml_backend_tp_finish_init_tensor(tensor);
1246+
};
12181247

1219-
if (false) {
1248+
bool force_rejoin = true;
1249+
if (force_rejoin) {
12201250
for (int i = 0; i < GGML_MAX_SRC; i++) {
12211251
auto src = tensor->src[i];
12221252
if (!src) {
@@ -1582,7 +1612,8 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
15821612
break;
15831613
}
15841614

1585-
case GGML_OP_MUL_MAT: {
1615+
case GGML_OP_MUL_MAT:
1616+
case GGML_OP_MUL_MAT_ID: {
15861617
no_split_view(src0, src0_extra);
15871618
if (tensor->view_src) {
15881619
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
@@ -1683,76 +1714,17 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
16831714
}
16841715

16851716
case GGML_OP_CPY: {
1686-
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
1687-
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
1688-
1689-
if (src1_extra->split_tensors) {
1690-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src1 is split.\n", tensor->name, ggml_op_name(tensor->op));
1691-
}
1692-
1693-
if (!src0_split_tensors) {
1694-
create_default_tensors();
1695-
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1696-
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1697-
}
1698-
else {
1699-
// GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
1700-
auto view_src = src0;
1701-
while (view_src->view_src) {
1702-
view_src = view_src->view_src;
1703-
}
1704-
auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
1705-
1706-
if (ggml_are_same_shape(tensor, view_src)) {
1707-
if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1708-
create_column_split_tensors_for(src0, src0_extra);
1709-
ggml_backend_tp_finish_init_tensor(src0);
1710-
ensure_column_split(src1);
1711-
create_column_split_tensors();
1712-
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1713-
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
1714-
}
1715-
else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1716-
create_row_split_tensors_for(src0, src0_extra);
1717-
ggml_backend_tp_finish_init_tensor(src0);
1718-
ensure_row_split(src1);
1719-
create_row_split_tensors();
1720-
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1721-
set_src_tensor(1, GGML_TP_SPLIT_ROWS);
1722-
}
1723-
}
1724-
else {
1725-
if (src0_extra->split_tensors == GGML_TP_SPLIT_VIEW) {
1726-
if (tensor->ne[0] == view_src->ne[0]) {
1727-
if (view_src_extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1728-
create_column_split_tensors_for(src0, src0_extra);
1729-
ggml_backend_tp_finish_init_tensor(src0);
1730-
}
1731-
else if (view_src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
1732-
create_row_split_tensors_for(src0, src0_extra);
1733-
ggml_backend_tp_finish_init_tensor(src0);
1734-
}
1735-
else {
1736-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), src0_extra->split_tensors, GGML_TP_SPLIT_NONE);
1737-
}
1738-
}
1739-
else if (tensor->ne[0] > view_src->ne[0]) {
1740-
create_column_split_tensors_for(src0, src0_extra);
1741-
ggml_backend_tp_finish_init_tensor(src0);
1742-
}
1743-
else {
1744-
create_row_split_tensors_for(src0, src0_extra);
1745-
ggml_backend_tp_finish_init_tensor(src0);
1746-
}
1747-
}
1748-
1749-
ensure_rejoined(tensor, src0);
1750-
1751-
create_default_tensors();
1752-
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1753-
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1754-
}
1755-
}
1717+
// the src1 is the destination, and has already been created.
1718+
// it may be op NONE or op VIEW. Without graph introspection,
1719+
// it is possible to use this cpy op to make the src1 tensor tree
1720+
// split, but this is simpler for now.
1721+
ensure_init_from_viewsrc(src0, src0_extra);
1722+
ensure_init_from_viewsrc(src1, src1_extra);
1723+
ensure_rejoined(tensor, src0);
1724+
ensure_rejoined(tensor, src1);
1725+
create_default_tensors();
1726+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1727+
set_src_tensor(1, GGML_TP_SPLIT_NONE);
17561728

17571729
break;
17581730
}
@@ -2554,7 +2526,7 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
25542526
}
25552527
}
25562528

2557-
if (op->op != GGML_OP_MUL_MAT) {
2529+
if (op->op != GGML_OP_MUL_MAT && op->op != GGML_OP_MUL_MAT_ID) {
25582530
for (int i = 0; i < GGML_MAX_SRC; i++) {
25592531
auto src = op->src[i];
25602532
if (!src) {

0 commit comments

Comments
 (0)