Skip to content

Commit 2b6050c

Browse files
committed
wip
1 parent a70bbb7 commit 2b6050c

File tree

1 file changed

+36
-35
lines changed

1 file changed

+36
-35
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,7 +1530,42 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
15301530
auto src0_split_tensors = src0_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src0_extra->split_tensors;
15311531
auto src1_split_tensors = src1_extra->has_rejoin ? GGML_TP_SPLIT_NONE : src1_extra->split_tensors;
15321532

1533-
if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
1533+
if (!src0_split_tensors && !src1_split_tensors) {
1534+
// straight add
1535+
ensure_rejoined(tensor, src0);
1536+
ensure_rejoined(tensor, src1);
1537+
create_default_tensors();
1538+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1539+
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1540+
}
1541+
else if ((src0_split_tensors || src1_split_tensors) && (src1_split_tensors != GGML_TP_SPLIT_REDUCE && src0_split_tensors != GGML_TP_SPLIT_REDUCE)) {
1542+
auto split = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
1543+
if (split == GGML_TP_SPLIT_ROWS) {
1544+
ensure_row_split(src0);
1545+
ensure_row_split(src1);
1546+
create_row_split_tensors();
1547+
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1548+
set_src_tensor(1, GGML_TP_SPLIT_ROWS);
1549+
}
1550+
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS) {
1551+
ensure_column_split(src0);
1552+
ensure_column_split(src1);
1553+
create_column_split_tensors();
1554+
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1555+
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
1556+
}
1557+
else if (src0_split_tensors == GGML_TP_SPLIT_DIM2) {
1558+
ensure_dim2_split(src0);
1559+
ensure_dim2_split(src1);
1560+
create_row_split_tensors();
1561+
set_src_tensor(0, GGML_TP_SPLIT_DIM2);
1562+
set_src_tensor(1, GGML_TP_SPLIT_DIM2);
1563+
}
1564+
else {
1565+
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but src1 is split as %d.\n", tensor->name, ggml_op_name(tensor->op), src0_split_tensors, src1_split_tensors);
1566+
}
1567+
}
1568+
else if (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
15341569
create_reduce_tensors();
15351570
create_reduce_op_tensors();
15361571
}
@@ -1550,47 +1585,13 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
15501585
create_reduce_tensors();
15511586
create_reduce_op_tensors();
15521587
}
1553-
else if (!src0_split_tensors && !src1_split_tensors) {
1554-
ensure_rejoined(tensor, src0);
1555-
ensure_rejoined(tensor, src1);
1556-
create_default_tensors();
1557-
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1558-
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1559-
}
15601588
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS && !src1_split_tensors) {
15611589
ensure_column_split(src0);
15621590
ensure_column_split(src1);
15631591
create_column_split_tensors();
15641592
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
15651593
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
15661594
}
1567-
else if (src0_split_tensors == src1_split_tensors) {
1568-
auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
1569-
if (split_tensors == GGML_TP_SPLIT_ROWS) {
1570-
ensure_row_split(src0);
1571-
ensure_row_split(src1);
1572-
create_row_split_tensors();
1573-
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1574-
set_src_tensor(1, GGML_TP_SPLIT_ROWS);
1575-
}
1576-
else if (split_tensors == GGML_TP_SPLIT_COLUMNS) {
1577-
ensure_column_split(src0);
1578-
ensure_column_split(src1);
1579-
create_column_split_tensors();
1580-
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1581-
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
1582-
}
1583-
else if (!split_tensors) {
1584-
ensure_rejoined(tensor, src0);
1585-
ensure_rejoined(tensor, src1);
1586-
create_default_tensors();
1587-
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1588-
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1589-
}
1590-
else {
1591-
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as %d but src1 is split as %d.\n", tensor->name, ggml_op_name(tensor->op), src0_split_tensors, src1_split_tensors);
1592-
}
1593-
}
15941595
else {
15951596
// mismatched, so join both
15961597
ensure_rejoined(tensor, src0);

0 commit comments

Comments
 (0)