Skip to content

Commit fce09c6

Browse files
committed
tp get rows
1 parent d4c8306 commit fce09c6

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -226,10 +226,15 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
226226
if (i == 1 && tensor->op == GGML_OP_ROPE) {
227227
wrapped->src[i] = src_extra->tensors[j];
228228
}
229+
229230
if (i == 3 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
230231
wrapped->src[i] = src_extra->tensors[j];
231232
}
232233

234+
if (i == 1 && tensor->op == GGML_OP_GET_ROWS) {
235+
wrapped->src[i] = src_extra->tensors[j];
236+
}
237+
233238
if (tensor->op == GGML_OP_RMS_NORM) {
234239
ggml_tensor * check;
235240
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
@@ -261,7 +266,7 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
261266
}
262267

263268
if (!wrapped->src[i]) {
264-
GGML_LOG_ERROR("Tensor %s unwrap failure.\n", tensor->name, src->name);
269+
GGML_ABORT("Tensor %s unwrap failure.\n", tensor->name, src->name);
265270
}
266271
}
267272
}
@@ -329,7 +334,7 @@ static bool is_split_compatible(ggml_tensor * tensor) {
329334
case GGML_OP_ROPE:
330335
case GGML_OP_FLASH_ATTN_EXT:
331336
case GGML_OP_CPY:
332-
// case GGML_OP_GET_ROWS:
337+
case GGML_OP_GET_ROWS:
333338
return true;
334339
default:
335340
return false;
@@ -1625,8 +1630,10 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
16251630
auto src0 = tensor->src[0];
16261631
auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra;
16271632
}
1633+
else if (tensor->op == GGML_OP_GET_ROWS) {
1634+
// nothing to split.
1635+
}
16281636
else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
1629-
int i = 0;
16301637
ensure_row_split(tensor->src[0]);
16311638
ensure_dim2_split(tensor->src[1]);
16321639
ensure_dim2_split(tensor->src[2]);

0 commit comments

Comments
 (0)