Skip to content

Commit 8ca9825

Browse files
committed
wip
1 parent d772a14 commit 8ca9825

File tree

1 file changed

+114
-19
lines changed

1 file changed

+114
-19
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 114 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,8 @@ static size_t ggml_backend_tp_buffer_type_get_alignment(ggml_backend_buffer_type
407407
static ggml_status ensure_dim2_split(const ggml_tensor *src) {
408408
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
409409
if (src_extra->split_tensors) {
410-
if (src_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
411-
GGML_ABORT("Tensor %s is already split as %d, but requested to be split as rows.\n", src->name, src_extra->split_tensors);
410+
if (src_extra->split_tensors != GGML_TP_SPLIT_DIM2) {
411+
GGML_ABORT("Tensor %s is already split as %d, but requested to be split as dim2.\n", src->name, src_extra->split_tensors);
412412
}
413413
// this tensor is already split, so we don't need to do anything
414414
return GGML_STATUS_SUCCESS;
@@ -651,6 +651,10 @@ static void ensure_reduce_split_views(const ggml_tensor *tensor) {
651651
}
652652

653653
static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src) {
654+
auto ctx = (ggml_backend_tp_buffer_context *)src->buffer->context;
655+
if (ctx->split) {
656+
GGML_ABORT("Tensor %s is a split buffer, but rejoin requested.\n", src->name);
657+
}
654658
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
655659
if (!src_extra->split_tensors) {
656660
// this tensor is not split, so we can't rejoin it
@@ -675,8 +679,6 @@ static void ensure_rejoined(const ggml_tensor *reason, const ggml_tensor * src)
675679

676680
const auto alignment = ggml_backend_tp_buffer_type_get_alignment(src->buffer->buft);
677681

678-
auto ctx = (ggml_backend_tp_buffer_context *)src->buffer->context;
679-
680682
auto reduce_scale = src_extra->split_tensors == GGML_TP_SPLIT_REDUCE ? ggml_parallel_devices.size() : 1;
681683

682684
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
@@ -1110,7 +1112,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11101112
create_row_split_tensors_for(tensor, extra);
11111113
};
11121114

1113-
auto create_column_split_tensors = [&]() {
1115+
auto create_column_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11141116
extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
11151117
auto splits = get_col_splits(tensor);
11161118
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
@@ -1127,6 +1129,25 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11271129
}
11281130
};
11291131

1132+
auto create_column_split_tensors = [&]() {
1133+
create_column_split_tensors_for(tensor, extra);
1134+
};
1135+
1136+
auto create_dim2_split_tensors_for = [](ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
1137+
extra->split_tensors = GGML_TP_SPLIT_DIM2;
1138+
auto splits = get_dim_splits(tensor->ne[2]);
1139+
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
1140+
auto dev = ggml_parallel_devices[j];
1141+
auto wrapped = ggml_backend_tp_clone_tensor(tensor);
1142+
extra->tensors[j] = wrapped;
1143+
1144+
// update dim2 count
1145+
wrapped->ne[2] = splits.split[j];
1146+
// adjust the stride for the new dim2 count
1147+
wrapped->nb[3] = wrapped->nb[3] / tensor->ne[2] * splits.split[j];
1148+
}
1149+
};
1150+
11301151
bool has_init = false;
11311152
auto create_reduce_op_tensors = [&] {
11321153
if (has_init) {
@@ -1271,6 +1292,31 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12711292
}
12721293
};
12731294

1295+
bool force_rejoin = true;
1296+
switch (tensor->op) {
1297+
case GGML_OP_MUL_MAT:
1298+
force_rejoin = false;
1299+
break;
1300+
}
1301+
1302+
if (force_rejoin) {
1303+
for (int i = 0; i < GGML_MAX_SRC; i++) {
1304+
auto src = tensor->src[i];
1305+
if (!src) {
1306+
break;
1307+
}
1308+
auto ctx = (ggml_backend_tp_buffer_context *)src->buffer->context;
1309+
if (ctx->split) {
1310+
// this tensor lives in a split buffer, so it can not be rejoined
1311+
continue;
1312+
}
1313+
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
1314+
if (src_extra->split_tensors) {
1315+
ensure_rejoined(tensor, src);
1316+
}
1317+
}
1318+
}
1319+
12741320
switch (tensor->op) {
12751321
case GGML_OP_ROPE: {
12761322
if (tensor->view_src) {
@@ -1283,6 +1329,11 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12831329
// make this into columns and create views into it
12841330
auto src0_viewsrc = src0->view_src;
12851331
auto src0_viewsrc_extra = (ggml_tensor_parallel_extra *)src0_viewsrc->extra;
1332+
1333+
if (src0_extra->tensors[0]) {
1334+
GGML_ABORT("Reshape Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1335+
}
1336+
12861337
if (src0_viewsrc_extra->split_tensors != GGML_TP_SPLIT_COLUMNS) {
12871338
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns.\n", tensor->name, ggml_op_name(tensor->op));
12881339
}
@@ -1291,10 +1342,6 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12911342
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as view but not evenly divisible by the rope head count.\n", tensor->name, ggml_op_name(tensor->op));
12921343
}
12931344

1294-
if (src0_extra->tensors[0]) {
1295-
GGML_ABORT("Reshape Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1296-
}
1297-
12981345
// rope input is initially on columns.
12991346
// input to rope is split [8192,1,1,1], per gpu it is [4096,1,1,1]
13001347
// the input is then reshaped [128,64,1,1] per gpu it is [128,32,1,1]
@@ -1303,6 +1350,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13031350
// this means that the rope output is now split on rows.
13041351
src0_split_tensors = GGML_TP_SPLIT_ROWS;
13051352
create_row_split_tensors_for(src0, src0_extra);
1353+
ggml_backend_tp_finish_init_tensor(src0);
13061354
}
13071355

13081356
if (!src0_split_tensors) {
@@ -1328,7 +1376,6 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13281376
}
13291377

13301378
case GGML_OP_FLASH_ATTN_EXT: {
1331-
no_split_view(src0, src0_extra);
13321379
no_split_view(src1, src1_extra);
13331380
no_split_view(src2, src2_extra);
13341381
no_split_view(src3, src3_extra);
@@ -1340,6 +1387,29 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13401387
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
13411388
auto src2_split_tensors = src2_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src2_extra->split_tensors;
13421389

1390+
if (src0_split_tensors == GGML_TP_SPLIT_VIEW) {
1391+
// make this into columns and create views into it
1392+
auto src0_viewsrc = src0->view_src;
1393+
auto src0_viewsrc_extra = (ggml_tensor_parallel_extra *)src0_viewsrc->extra;
1394+
1395+
if (src0_extra->tensors[0]) {
1396+
GGML_ABORT("Reshape Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1397+
}
1398+
1399+
if (src0_viewsrc_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
1400+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as rows.\n", tensor->name, ggml_op_name(tensor->op));
1401+
}
1402+
1403+
if (tensor->ne[1] % src0_viewsrc->ne[1]) {
1404+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as view but not evenly divisible by the rope head count.\n", tensor->name, ggml_op_name(tensor->op));
1405+
}
1406+
1407+
// similar to ROPE above, the input must be row split; here it becomes dim2 split.
1408+
src0_split_tensors = GGML_TP_SPLIT_DIM2;
1409+
create_dim2_split_tensors_for(src0, src0_extra);
1410+
ggml_backend_tp_finish_init_tensor(src0);
1411+
}
1412+
13431413
if (!src0_split_tensors && !src1_split_tensors && !src2_split_tensors) {
13441414
create_default_tensors();
13451415
set_src_tensor(0, GGML_TP_SPLIT_NONE);
@@ -1348,14 +1418,14 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13481418
set_src_tensor(3, GGML_TP_SPLIT_NONE);
13491419
}
13501420
else {
1351-
ensure_row_split(src0);
1421+
ensure_dim2_split(src0);
13521422
ensure_dim2_split(src1);
13531423
ensure_dim2_split(src2);
1354-
create_column_split_tensors();
1355-
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1356-
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
1357-
set_src_tensor(2, GGML_TP_SPLIT_COLUMNS);
1358-
set_src_tensor(3, GGML_TP_SPLIT_DIM2);
1424+
create_row_split_tensors();
1425+
set_src_tensor(0, GGML_TP_SPLIT_DIM2);
1426+
set_src_tensor(1, GGML_TP_SPLIT_DIM2);
1427+
set_src_tensor(2, GGML_TP_SPLIT_DIM2);
1428+
set_src_tensor(3, GGML_TP_SPLIT_NONE);
13591429
}
13601430
check_srcs();
13611431
break;
@@ -1384,7 +1454,9 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
13841454
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
13851455
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
13861456

1387-
if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
1457+
// sometimes src0/src1 may be used as input twice, one of which needs a rejoin.
1458+
// so check the native split state rather than the rejoin state here.
1459+
if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
13881460
create_reduce_tensors();
13891461
create_reduce_op_tensors();
13901462
}
@@ -1530,7 +1602,6 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15301602

15311603
case GGML_OP_MUL_MAT: {
15321604
no_split_view(src0, src0_extra);
1533-
no_split_view(src1, src1_extra);
15341605
if (tensor->view_src) {
15351606
GGML_ABORT("Tensor %s has view source tensors, which are not supported for tensor parallelism.\n", tensor->name);
15361607
}
@@ -1541,6 +1612,30 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
15411612
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
15421613
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
15431614

1615+
1616+
if (src1_split_tensors == GGML_TP_SPLIT_VIEW) {
1617+
// make this into columns and create views into it
1618+
auto src1_viewsrc = src1->view_src;
1619+
auto src1_viewsrc_extra = (ggml_tensor_parallel_extra *)src1_viewsrc->extra;
1620+
1621+
if (src1_extra->tensors[0]) {
1622+
GGML_ABORT("Reshape Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1623+
}
1624+
1625+
if (src1_viewsrc_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
1626+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src1 is split but not as rows.\n", tensor->name, ggml_op_name(tensor->op));
1627+
}
1628+
1629+
if ((src1_viewsrc->ne[0] * src1_viewsrc->ne[1]) % tensor->ne[0]) {
1630+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src1 is split as view but not evenly divisible by the rope head count.\n", tensor->name, ggml_op_name(tensor->op));
1631+
}
1632+
1633+
// similar to ROPE above, the input must be row split and becomes column split.
1634+
src1_split_tensors = GGML_TP_SPLIT_COLUMNS;
1635+
create_column_split_tensors_for(src1, src1_extra);
1636+
ggml_backend_tp_finish_init_tensor(src1);
1637+
}
1638+
15441639
if (!src0_split_tensors && !src1_split_tensors) {
15451640
create_default_tensors();
15461641
set_src_tensor(0, GGML_TP_SPLIT_NONE);
@@ -1649,7 +1744,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
16491744
create_default_tensors();
16501745
}
16511746
else {
1652-
GGML_LOG_WARN("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
1747+
// GGML_LOG_WARN("Tensor %s has unsupported op %s for tensor parallelism, src0 is split.\n", tensor->name, ggml_op_name(tensor->op));
16531748
has_init = true;
16541749
extra->split_tensors = GGML_TP_SPLIT_VIEW;
16551750
}

0 commit comments

Comments
 (0)