Skip to content

Commit 8eaa479

Browse files
committed
mulmat id stub
1 parent 412fe47 commit 8eaa479

File tree

1 file changed

+42
-2
lines changed

1 file changed

+42
-2
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,18 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
962962

963963
extra->initialized = true;
964964

965+
// ensure all src are initialized, out of order usage is possible on view tensors.
966+
for (size_t i = 0; i < GGML_MAX_SRC; i++) {
967+
auto src = tensor->src[i];
968+
if (!src) {
969+
break;
970+
}
971+
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
972+
// node index is incorrect here, but need something for debugging.
973+
do_init(node_index, src, src_extra);
974+
}
975+
976+
965977
auto src0 = tensor->src[0];
966978
auto src1 = tensor->src[1];
967979
auto src2 = tensor->src[2];
@@ -1608,7 +1620,6 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
16081620

16091621
GGML_ASSERT(ggml_are_same_shape(extra->tensors[0], extra->tensors[0]->src[0]) && "Tensor parallel tensors must have the same shape.");
16101622
GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[0]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne1).");
1611-
GGML_ASSERT(extra->tensors[0]->ne[0] == extra->tensors[0]->src[1]->ne[0] && "Tensor parallel has incorrect broadcast dimension (ne1).");
16121623
break;
16131624
}
16141625

@@ -1725,7 +1736,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17251736
create_default_tensors();
17261737
set_src_tensor(0, GGML_TP_SPLIT_NONE);
17271738
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1728-
1739+
check_srcs();
17291740
break;
17301741
}
17311742

@@ -1762,6 +1773,31 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17621773
break;
17631774
}
17641775

1776+
case GGML_OP_SUM_ROWS: {
1777+
no_split_view(src0, src0_extra);
1778+
if (extra->split_tensors == GGML_TP_SPLIT_COLUMNS) {
1779+
create_column_split_tensors();
1780+
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1781+
}
1782+
else {
1783+
ensure_rejoined(tensor, src0);
1784+
create_default_tensors();
1785+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1786+
}
1787+
check_srcs();
1788+
break;
1789+
}
1790+
case GGML_OP_SOFT_MAX:
1791+
case GGML_OP_ARGSORT: {
1792+
no_split_view(src0, src0_extra);
1793+
1794+
ensure_rejoined(tensor, src0);
1795+
create_default_tensors();
1796+
set_src_tensor(0, GGML_TP_SPLIT_NONE);
1797+
check_srcs();
1798+
break;
1799+
}
1800+
17651801
case GGML_OP_VIEW:
17661802
case GGML_OP_PERMUTE:
17671803
case GGML_OP_RESHAPE: {
@@ -2513,6 +2549,10 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
25132549
GGML_UNUSED(dev);
25142550
GGML_UNUSED(op);
25152551

2552+
if (op->op == GGML_OP_MUL_MAT_ID) {
2553+
return false;
2554+
}
2555+
25162556
auto buft = op->buffer ? op->buffer->buft : nullptr;
25172557
if (buft && (!ggml_backend_buft_is_tp_split(buft) && !ggml_backend_buft_is_tp(buft))) {
25182558
return false;

0 commit comments

Comments
 (0)