Skip to content

Commit d772a14

Browse files
committed
wip
1 parent ea3fbff commit d772a14

File tree

1 file changed

+79
-19
lines changed

1 file changed

+79
-19
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ enum ggml_tp_split_type {
4747
GGML_TP_SPLIT_COLUMNS = 2,
4848
GGML_TP_SPLIT_REDUCE = 3,
4949
GGML_TP_SPLIT_DIM2 = 4,
50+
GGML_TP_SPLIT_VIEW = 5,
5051
};
5152

5253
struct ggml_tensor_parallel_extra {
@@ -1089,7 +1090,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
10891090
}
10901091
};
10911092

1092-
auto create_row_split_tensors = [&]() {
1093+
auto create_row_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
10931094
extra->split_tensors = GGML_TP_SPLIT_ROWS;
10941095
auto splits = get_row_splits(tensor);
10951096
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
@@ -1105,6 +1106,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11051106
}
11061107
};
11071108

1109+
auto create_row_split_tensors = [&]() {
1110+
create_row_split_tensors_for(tensor, extra);
1111+
};
1112+
11081113
auto create_column_split_tensors = [&]() {
11091114
extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
11101115
auto splits = get_col_splits(tensor);
@@ -1212,6 +1217,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12121217

12131218
if (split == GGML_TP_SPLIT_NONE) {
12141219
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1220+
ensure_reduce_split_views(tensor->src[src_index]);
12151221
wrapped->src[src_index] = src_extra->reduce_split_views[j];
12161222
}
12171223
else if (src_extra->split_tensors) {
@@ -1237,6 +1243,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12371243
wrapped->src[src_index] = src_extra->tensors[j];
12381244
}
12391245
else if (src_extra->split_tensors) {
1246+
ensure_reduce_split_views(tensor->src[src_index]);
12401247
wrapped->src[src_index] = src_extra->reduce_split_views[j];
12411248
}
12421249
else {
@@ -1266,19 +1273,38 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12661273

12671274
switch (tensor->op) {
12681275
case GGML_OP_ROPE: {
1269-
no_split_view(src0, src0_extra);
12701276
if (tensor->view_src) {
12711277
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
12721278
}
12731279

1274-
if (src0_extra->split_tensors && src0_extra->split_tensors != GGML_TP_SPLIT_COLUMNS) {
1275-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns.\n", tensor->name, ggml_op_name(tensor->op));
1276-
// technically, this is not a problem, but it is not expected.
1277-
// ensure_rejoined(tensor, src0);
1278-
}
1279-
12801280
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
12811281

1282+
if (src0_split_tensors == GGML_TP_SPLIT_VIEW) {
1283+
// make this into columns and create views into it
1284+
auto src0_viewsrc = src0->view_src;
1285+
auto src0_viewsrc_extra = (ggml_tensor_parallel_extra *)src0_viewsrc->extra;
1286+
if (src0_viewsrc_extra->split_tensors != GGML_TP_SPLIT_COLUMNS) {
1287+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns.\n", tensor->name, ggml_op_name(tensor->op));
1288+
}
1289+
1290+
if (src0_viewsrc->ne[0] % tensor->ne[0]) {
1291+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as view but not evenly divisible by the rope head count.\n", tensor->name, ggml_op_name(tensor->op));
1292+
}
1293+
1294+
if (src0_extra->tensors[0]) {
1295+
GGML_ABORT("Reshape Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1296+
}
1297+
1298+
// rope input is initially on columns.
1299+
// input to rope is split [8192,1,1,1], per gpu it is [4096,1,1,1]
1300+
// the input is then reshaped [128,64,1,1] per gpu it is [128,32,1,1]
1301+
// this effectively splits it on the number of heads: 64 -> 32 heads.
1302+
// the output from rope is [128,64,1,1] per gpu it is [128,32,1,1]
1303+
// this means that the rope output is now split on rows.
1304+
src0_split_tensors = GGML_TP_SPLIT_ROWS;
1305+
create_row_split_tensors_for(src0, src0_extra);
1306+
}
1307+
12821308
if (!src0_split_tensors) {
12831309
create_default_tensors();
12841310
set_src_tensor(0, GGML_TP_SPLIT_NONE);
@@ -1287,23 +1313,29 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12871313
set_src_tensor(2, GGML_TP_SPLIT_NONE);
12881314
}
12891315
}
1290-
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
1291-
// rope input is initially on columns.
1292-
// input to rope is split [8192,1,1,1], per gpu it is [4096,1,1,1]
1293-
// rope is then reshaped [128,64,1,1] per gpu it is [128,32,1,1]
1294-
// this effectively splits it on the head dim 64->32 heads.
1295-
// the output from rope is [128,64,1,1] per gpu it is [128,32,1,1]
1296-
// this means that the rope output is now split on rows.
1297-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as columns.\n", tensor->name, ggml_op_name(tensor->op));
1316+
else if (src0_split_tensors == GGML_TP_SPLIT_ROWS) {
1317+
create_row_split_tensors();
1318+
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1319+
set_src_tensor(1, GGML_TP_SPLIT_NONE);
12981320
}
12991321
else {
1300-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns.\n", tensor->name, ggml_op_name(tensor->op));
1322+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as rows.\n", tensor->name, ggml_op_name(tensor->op));
1323+
// technically, this is not a problem, but it is not expected.
1324+
// ensure_rejoined(tensor, src0);
13011325
}
13021326

13031327
break;
13041328
}
13051329

13061330
case GGML_OP_FLASH_ATTN_EXT: {
1331+
no_split_view(src0, src0_extra);
1332+
no_split_view(src1, src1_extra);
1333+
no_split_view(src2, src2_extra);
1334+
no_split_view(src3, src3_extra);
1335+
if (tensor->view_src) {
1336+
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
1337+
}
1338+
13071339
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
13081340
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
13091341
auto src2_split_tensors = src2_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src2_extra->split_tensors;
@@ -1334,6 +1366,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13341366
if (tensor->view_src) {
13351367
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
13361368
}
1369+
13371370
ensure_rejoined(tensor, src0);
13381371
create_default_tensors();
13391372
set_src_tensor(0, GGML_TP_SPLIT_NONE);
@@ -1355,7 +1388,12 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13551388
create_reduce_tensors();
13561389
create_reduce_op_tensors();
13571390
}
1358-
else if (src0_split_tensors & !src1_split_tensors) {
1391+
else if (!src0_split_tensors && !src1_split_tensors) {
1392+
create_default_tensors();
1393+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1394+
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1395+
}
1396+
else if ((src0_split_tensors ^ src1_split_tensors) || (src0_split_tensors == src1_split_tensors)) {
13591397
auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
13601398

13611399
if (split_tensors == GGML_TP_SPLIT_COLUMNS) {
@@ -1398,6 +1436,11 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13981436
}
13991437

14001438
case GGML_OP_UNARY: {
1439+
no_split_view(src0, src0_extra);
1440+
if (tensor->view_src) {
1441+
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
1442+
}
1443+
14011444
no_reduce(src0, src0_extra);
14021445
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
14031446

@@ -1565,34 +1608,51 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15651608
}
15661609

15671610
case GGML_OP_GET_ROWS: {
1611+
no_split_view(src0, src0_extra);
1612+
no_split_view(src1, src1_extra);
1613+
if (tensor->view_src) {
1614+
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
1615+
}
1616+
15681617
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
15691618
if (src0_split_tensors == GGML_TP_SPLIT_ROWS) {
15701619
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as rows.\n", tensor->name, ggml_op_name(tensor->op));
15711620
// technically supported, but not expected.
15721621
}
15731622
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
15741623
create_column_split_tensors();
1624+
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
15751625
}
15761626
else if (!src0_split_tensors) {
15771627
create_default_tensors();
1628+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
15781629
}
15791630
else if (src0_split_tensors == GGML_TP_SPLIT_REDUCE) {
15801631
// this actually works out fine because the rows can be fetched per split and then added together.
15811632
create_reduce_tensors();
1633+
set_src_tensor(0, GGML_TP_SPLIT_REDUCE);
15821634
}
15831635
else {
15841636
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns or rows.\n", tensor->name, ggml_op_name(tensor->op));
15851637
}
1638+
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1639+
check_srcs();
15861640
break;
15871641
}
15881642

15891643
case GGML_OP_VIEW:
15901644
case GGML_OP_PERMUTE:
15911645
case GGML_OP_RESHAPE: {
1646+
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
15921647
// if split, skip and let the downstream op make sense of it, since some graphs chain together many reshapes/permutes/views.
1593-
if (!extra->split_tensors) {
1648+
if (!src0_split_tensors) {
15941649
create_default_tensors();
15951650
}
1651+
else {
1652+
GGML_LOG_WARN("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
1653+
has_init = true;
1654+
extra->split_tensors = GGML_TP_SPLIT_VIEW;
1655+
}
15961656
break;
15971657
}
15981658
default:

0 commit comments

Comments
 (0)