
Commit 20dc0d1

Browse files
committed: wip

1 parent 8ca9825

File tree

1 file changed: +50 −11 lines


ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 50 additions & 11 deletions
@@ -1294,6 +1294,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
 
     bool force_rejoin = true;
     switch (tensor->op) {
+        case GGML_OP_MUL:
         case GGML_OP_MUL_MAT:
            force_rejoin = false;
            break;
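
Note: adding GGML_OP_MUL alongside GGML_OP_MUL_MAT means element-wise multiply no longer forces a rejoin. The reason this is safe, as a sketch with illustrative names rather than the actual ggml-tp helpers: element-wise ops commute with splitting, so each device can multiply its own shards and the per-device results are already the correctly split product.

    // sketch: per-device element-wise multiply on identically split shards
    // (ctx, a_shard, b_shard are hypothetical per-device arrays, not ggml-tp API)
    for (int dev = 0; dev < n_devices; dev++) {
        ggml_tensor * prod = ggml_mul(ctx[dev], a_shard[dev], b_shard[dev]);
        // prod is the dev-th shard of (a * b); no cross-device traffic needed
    }
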
@@ -1443,8 +1444,8 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
             check_srcs();
             break;
 
-        case GGML_OP_ADD:
-        case GGML_OP_SUB: {
+        case GGML_OP_SUB:
+        case GGML_OP_ADD: {
             no_split_view(src0, src0_extra);
             no_split_view(src1, src1_extra);
             if (tensor->view_src) {
@@ -1454,16 +1455,17 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
             auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
             auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
 
-            // sometimes src0/src1 may be used as input twice, one of which needs a rejoin.
-            // so check the native split state rather than the rejoin state here.
-            if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
+            if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
                 create_reduce_tensors();
                 create_reduce_op_tensors();
             }
             else if (!src0_split_tensors && !src1_split_tensors) {
+                ensure_rejoined(tensor, src0);
+                ensure_rejoined(tensor, src1);
                 create_default_tensors();
                 set_src_tensor(0, GGML_TP_SPLIT_NONE);
                 set_src_tensor(1, GGML_TP_SPLIT_NONE);
+                GGML_ASSERT(ggml_are_same_shape(extra->tensors[0], extra->tensors[0]->src[0]) && "Tensor parallel tensors must have the same shape.");
             }
             else if ((src0_split_tensors ^ src1_split_tensors) || (src0_split_tensors == src1_split_tensors)) {
                 auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
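
Note on the GGML_TP_SPLIT_REDUCE branch: when both inputs are still unreduced partial sums, the add can run directly on the partials and the single reduce can be deferred, since addition reorders freely:

    (a0 + a1) + (b0 + b1) == (a0 + b0) + (a1 + b1)

where a0/b0 live on device 0 and a1/b1 on device 1. Each device computes its local a_i + b_i, and one later all-reduce produces the full sum.
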
@@ -1504,6 +1506,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                 set_src_tensor(0, GGML_TP_SPLIT_NONE);
                 set_src_tensor(1, GGML_TP_SPLIT_NONE);
             }
+
+            GGML_ASSERT(ggml_are_same_shape(extra->tensors[0], extra->tensors[0]->src[0]) && "Tensor parallel tensors must have the same shape.");
+            GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[0]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne0).");
+            GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[1]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne0).");
             break;
         }
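
Note on the new asserts: for GGML_OP_ADD/GGML_OP_SUB, ggml broadcasts src1 over src0 in the higher dims, and the asserts here pin dst->ne[0] against both sources. That is stricter than ggml's general repeat rule, but it matches the shapes tensor parallelism produces and catches shard-shape mistakes early. A typical shape this guards (illustrative names and values):

    // bias add: activations [n_embd, n_tokens], bias [n_embd, 1]
    ggml_tensor * out = ggml_add(ctx, cur, bias);   // broadcasts over tokens
    // out->ne[0] == cur->ne[0] == bias->ne[0] must all hold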

@@ -1531,6 +1537,8 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
             else {
                 GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but not as columns or rows.\n", tensor->name, ggml_op_name(tensor->op));
             }
+
+            GGML_ASSERT(ggml_are_same_shape(extra->tensors[0], extra->tensors[0]->src[0]) && "Tensor parallel tensors must have the same shape.");
             break;
         }

@@ -1550,16 +1558,22 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
             if (src0_split_tensors == src1_split_tensors) {
                 // splits match
                 if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
+                    ensure_column_split(src0);
+                    ensure_column_split(src1);
                     create_column_split_tensors();
                     set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
                     set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
                 }
                 else if (src0_split_tensors == GGML_TP_SPLIT_ROWS) {
+                    ensure_row_split(src0);
+                    ensure_row_split(src1);
                     create_row_split_tensors();
                     set_src_tensor(0, GGML_TP_SPLIT_ROWS);
                     set_src_tensor(1, GGML_TP_SPLIT_ROWS);
                 }
                 else if (src0_split_tensors == GGML_TP_SPLIT_NONE) {
+                    ensure_rejoined(tensor, src0);
+                    ensure_rejoined(tensor, src1);
                     create_default_tensors();
                     set_src_tensor(0, GGML_TP_SPLIT_NONE);
                     set_src_tensor(1, GGML_TP_SPLIT_NONE);
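
Note: the splits-match branch now normalizes both inputs with the matching ensure_* helper before building the output shards. For an element-wise op on two identically split tensors, the per-shard result is already the same split of the full result, e.g. (sketch; ctx_dev, a_shard, b_shard are illustrative):

    // both inputs split the same way; the sum inherits that split
    ggml_tensor * sum_shard = ggml_add(ctx_dev, a_shard, b_shard);
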
@@ -1580,11 +1594,15 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                 // one split, one not split
                 auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
                 if (split_tensors == GGML_TP_SPLIT_COLUMNS) {
                     ensure_column_split(src0);
                     ensure_column_split(src1);
                     create_column_split_tensors();
                     set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
                     set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
+                    // ensure_rejoined(nullptr, tensor);
                 }
                 else if (split_tensors == GGML_TP_SPLIT_ROWS) {
                     ensure_row_split(src0);
@@ -1597,6 +1615,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                     GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but src1 is not.\n", tensor->name, ggml_op_name(tensor->op));
                 }
             }
+
+            GGML_ASSERT(ggml_are_same_shape(extra->tensors[0], extra->tensors[0]->src[0]) && "Tensor parallel tensors must have the same shape.");
+            GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[0]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne0).");
+            GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[1]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne0).");
             break;
         }

@@ -1648,12 +1670,25 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
                 set_src_tensor(1, GGML_TP_SPLIT_NONE);
             }
             else if (src0_ctx->split && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
-                // a weight matrix is multiplied by a column split tensor (prior to ROPE), it can be massaged to a column split.
-                // this results in a reduce split.
-                ensure_weight_column_split(src0);
-                create_reduce_tensors();
-                set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
-                set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
+                if (false) {
+                    // not possible
+                    auto src0_ctx = (ggml_backend_tp_buffer_context *)src0->buffer->context;
+                    if (!src0_ctx->split) {
+                        ensure_rejoined(tensor, src0);
+                    }
+                    ensure_rejoined(tensor, src1);
+                    create_column_split_tensors();
+                    set_src_tensor(0, GGML_TP_SPLIT_ROWS);
+                    set_src_tensor(1, GGML_TP_SPLIT_NONE);
+                }
+                else {
+                    // a weight matrix is multiplied by a column split tensor (prior to ROPE); it can be massaged to a column split.
+                    // this results in a reduce split.
+                    ensure_weight_column_split(src0);
+                    create_reduce_tensors();
+                    set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
+                    set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
+                }
             }
             else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
                 // technically supported like the weights above, but not expected.
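
Note on the else branch kept above: with the weight's inner dimension split across devices, W = [W0 W1] in conventional math notation and the activation split to match, x = [x0; x1], each device computes a partial product, and the full result needs one reduce, hence create_reduce_tensors(). As a sketch:

    // device 0: y0 = W0 * x0        device 1: y1 = W1 * x1
    // full result: y = y0 + y1      (one all-reduce across devices)
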
@@ -1680,6 +1715,10 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
             else {
                 GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism.\n", tensor->name, ggml_op_name(tensor->op));
             }
+
+            GGML_ASSERT(extra->tensors[0]->src[0]->ne[0] == extra->tensors[0]->src[1]->ne[0] && "Tensor parallel tensors must have the same inner dimension (ne0).");
+            GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[0]->ne[1] && "Tensor parallel has incorrect outer dimension (ne0).");
+            GGML_ASSERT(extra->tensors[0]->ne[1] == extra->tensors[0]->src[1]->ne[1] && "Tensor parallel has incorrect outer dimension (ne1).");
             break;
         }
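
Note on the GGML_OP_MUL_MAT asserts: in ggml, c = ggml_mul_mat(ctx, a, b) requires a->ne[0] == b->ne[0] (the shared inner dimension) and yields c->ne[0] == a->ne[1] and c->ne[1] == b->ne[1]; the three asserts restate that contract on the per-device tensors. A self-contained check (K, M, N are illustrative sizes):

    // a: [K, M], b: [K, N]  =>  c: [M, N]
    const int64_t K = 64, M = 16, N = 8;
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, M);
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, K, N);
    ggml_tensor * c = ggml_mul_mat(ctx, a, b);
    GGML_ASSERT(a->ne[0] == b->ne[0]);                         // shared K
    GGML_ASSERT(c->ne[0] == a->ne[1] && c->ne[1] == b->ne[1]); // M and N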
