Skip to content

Commit d4c8306

Browse files
committed
tp flash attention
1 parent 706fafb commit d4c8306

File tree

1 file changed

+89
-20
lines changed

1 file changed

+89
-20
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum ggml_tp_split_type {
4545
GGML_TP_SPLIT_ROWS = 1,
4646
GGML_TP_SPLIT_COLUMNS = 2,
4747
GGML_TP_SPLIT_REDUCE = 3,
48+
GGML_TP_SPLIT_DIM2 = 4,
4849
};
4950

5051
struct ggml_tensor_parallel_extra {
@@ -225,7 +226,7 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
225226
if (i == 1 && tensor->op == GGML_OP_ROPE) {
226227
wrapped->src[i] = src_extra->tensors[j];
227228
}
228-
if (i > 0 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
229+
if (i == 3 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
229230
wrapped->src[i] = src_extra->tensors[j];
230231
}
231232

@@ -326,8 +327,9 @@ static bool is_split_compatible(ggml_tensor * tensor) {
326327
case GGML_OP_RESHAPE:
327328
case GGML_OP_PERMUTE:
328329
case GGML_OP_ROPE:
329-
// case GGML_OP_FLASH_ATTN_EXT:
330-
// case GGML_OP_CPY:
330+
case GGML_OP_FLASH_ATTN_EXT:
331+
case GGML_OP_CPY:
332+
// case GGML_OP_GET_ROWS:
331333
return true;
332334
default:
333335
return false;
@@ -393,9 +395,79 @@ static size_t ggml_backend_tp_buffer_type_get_alignment(ggml_backend_buffer_type
393395
GGML_UNUSED(buft);
394396
}
395397

396-
static ggml_status ensure_split(const ggml_tensor *src) {
398+
// Ensure `src` has per-device split views along dimension 2 (ne[2]).
//
// If the tensor is already split, or split views were already materialized,
// this is a no-op. Otherwise no data is moved: a per-device view tensor is
// created that aliases the original tensor's storage at an increasing byte
// offset, with only ne[2] shrunk to that device's share.
//
// Returns GGML_STATUS_SUCCESS, or aborts on an incompatible existing split.
static ggml_status ensure_dim2_split(const ggml_tensor *src) {
    auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
    if (src_extra->split_tensors) {
        // NOTE(review): accepting an existing ROWS split here mirrors the
        // original code — presumably a row split on the pre-permute layout
        // aliases a dim2 split after permutation; confirm this is intended
        // rather than a copy/paste of ensure_row_split's check.
        if (src_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
            GGML_ABORT("Tensor %s is already split as %d, but requested to be split as dim2.\n", src->name, src_extra->split_tensors);
        }
        // this tensor is already split, so we don't need to do anything
        return GGML_STATUS_SUCCESS;
    }
    if (src_extra->converted_tensors[0]) {
        // split views were already created by an earlier caller
        return GGML_STATUS_SUCCESS;
    }

    // no actual conversion needs to take place, the split tensors can be
    // created by using offsets within the original tensor.
    auto splits = get_dim_splits(src->ne[2]);
    size_t offset = 0;
    for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
        auto split = ggml_backend_tp_clone_tensor(src);
        split->op = GGML_OP_NONE;
        src_extra->converted_tensors[j] = split;

        split->buffer = src_extra->tensors[j]->buffer;
        split->data = (char *) src_extra->tensors[j]->data + offset;

        // note that only the dimension needs to be changed, retaining the stride allows
        // using the original tensor data for the dim2 split.
        split->ne[2] = splits.split[j];

        // NOTE(review): assumes nb[1] / ne[2] is the per-slice byte stride
        // along dim 2 for the layouts reaching this path — TODO confirm
        // (nb[2] is the conventional ggml stride for dimension 2).
        offset += src->nb[1] / src->ne[2] * splits.split[j];
    }
    return GGML_STATUS_SUCCESS;
}
431+
432+
// Ensure `src` has per-device split views along dimension 1 (rows, ne[1]).
//
// If the tensor is already row-split, or split views were already
// materialized, this is a no-op. Otherwise no data is moved: a per-device
// view tensor is created that aliases the original tensor's storage at an
// increasing byte offset, with only ne[1] shrunk to that device's share.
//
// Returns GGML_STATUS_SUCCESS, or aborts on an incompatible existing split.
static ggml_status ensure_row_split(const ggml_tensor *src) {
    auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
    if (src_extra->split_tensors) {
        if (src_extra->split_tensors != GGML_TP_SPLIT_ROWS) {
            GGML_ABORT("Tensor %s is already split as %d, but requested to be split as rows.\n", src->name, src_extra->split_tensors);
        }
        // this tensor is already split, so we don't need to do anything
        return GGML_STATUS_SUCCESS;
    }
    if (src_extra->converted_tensors[0]) {
        // split views were already created by an earlier caller
        return GGML_STATUS_SUCCESS;
    }

    // no actual conversion needs to take place, the split tensors can be
    // created by using offsets within the original tensor.
    auto splits = get_row_splits(src);
    size_t offset = 0;
    for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
        auto split = ggml_backend_tp_clone_tensor(src);
        split->op = GGML_OP_NONE;
        src_extra->converted_tensors[j] = split;

        split->buffer = src_extra->tensors[j]->buffer;
        split->data = (char *) src_extra->tensors[j]->data + offset;

        // note that only the dimension needs to be changed, retaining the stride allows
        // using the original tensor data for the row split.
        split->ne[1] = splits.split[j];

        // NOTE(review): assumes nb[0] / ne[1] is the per-row byte stride for
        // the layouts reaching this path — TODO confirm (nb[1] is the
        // conventional ggml stride for dimension 1).
        offset += src->nb[0] / src->ne[1] * splits.split[j];
    }
    // FIX: the original fell off the end of this non-void function (undefined
    // behavior in C++); its dim2/column siblings both return success here.
    return GGML_STATUS_SUCCESS;
}
464+
465+
static ggml_status ensure_column_or_reduce_split(const ggml_tensor *src) {
466+
auto src_extra = (ggml_tensor_parallel_extra *)src->extra;
467+
if (src_extra->split_tensors) {
468+
if (src_extra->split_tensors == GGML_TP_SPLIT_ROWS) {
469+
GGML_ABORT("Tensor %s is already split as %d, but requested to be split as columns.\n", src->name, src_extra->split_tensors);
470+
}
399471
// this tensor is already split, so we don't need to do anything
400472
return GGML_STATUS_SUCCESS;
401473
}
@@ -547,7 +619,7 @@ static ggml_status ensure_rejoined(const ggml_tensor *reason, const ggml_tensor
547619
src_extra->has_rejoin = true;
548620

549621
// if (reason && reason != src) {
550-
// printf("Rejoining tensor for %s %d %d\n", ggml_op_name(reason->op), src->ne[0], src->ne[1]);
622+
// printf("Rejoining tensor for %s %s\n", ggml_op_name(reason->op), ggml_op_name(src->op));
551623
// }
552624

553625
const auto alignment = ggml_backend_tp_buffer_type_get_alignment(src->buffer->buft);
@@ -907,16 +979,6 @@ static void ggml_backend_tp_buffer_graph_compute_one(struct compute_thread * thr
907979
continue;
908980
}
909981

910-
auto vs = tensor;
911-
while (vs->view_src) {
912-
vs = vs->view_src;
913-
}
914-
if ((vs->ne[2] > 1 || vs->ne[3] > 1)) {
915-
if (!ggml_is_contiguous(vs)) {
916-
GGML_ABORT("Tensor %s has more than 2 dimensions, not supported for TP.\n", tensor->name);
917-
}
918-
}
919-
920982
pending_rejoins.insert(tensor);
921983

922984
if (!be->iface.cpy_tensor2d_async) {
@@ -969,7 +1031,7 @@ static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, gg
9691031
std::set<ggml_backend_tp_buffer_context *> contexts;
9701032
for (auto tensor : tensors) {
9711033
auto extra = (ggml_tensor_parallel_extra *)tensor->extra;
972-
if (tensor->view_src && extra->needs_src_rejoin) {
1034+
if (tensor->view_src) {
9731035
auto view_src = tensor->view_src;
9741036
auto view_src_extra = (ggml_tensor_parallel_extra *)view_src->extra;
9751037
if (view_src_extra->has_rejoin) {
@@ -1563,14 +1625,20 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
15631625
auto src0 = tensor->src[0];
15641626
auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra;
15651627
}
1566-
else if (tensor->op != GGML_OP_UNARY && tensor->op != GGML_OP_FLASH_ATTN_EXT && tensor->op != GGML_OP_ROPE) {
1628+
else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
1629+
int i = 0;
1630+
ensure_row_split(tensor->src[0]);
1631+
ensure_dim2_split(tensor->src[1]);
1632+
ensure_dim2_split(tensor->src[2]);
1633+
}
1634+
else if (tensor->op != GGML_OP_UNARY && tensor->op != GGML_OP_ROPE) {
15671635
// printf("ggml_backend_tp_buffer_init_tensor: splitting tensor %s with op %s\n", tensor->name, ggml_op_name(tensor->op));
15681636
for (int i = 0; i < GGML_MAX_SRC; i++) {
15691637
auto src = tensor->src[i];
15701638
if (!src) {
15711639
break;
15721640
}
1573-
ensure_split(src);
1641+
ensure_column_or_reduce_split(src);
15741642
}
15751643
}
15761644
}
@@ -1585,8 +1653,8 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
15851653
bool can_reduce = (src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src0_extra->has_rejoin) || (src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE && !src1_extra->has_rejoin);
15861654
bool double_reduce = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE && src1_extra->split_tensors == GGML_TP_SPLIT_REDUCE;
15871655
if (can_reduce) {
1588-
ensure_split(src0);
1589-
ensure_split(src1);
1656+
ensure_column_or_reduce_split(src0);
1657+
ensure_column_or_reduce_split(src1);
15901658
extra->split_tensors = GGML_TP_SPLIT_REDUCE;
15911659
split_reduced_add = true;
15921660
}
@@ -1741,6 +1809,7 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
17411809
wrapped->nb[1] = src_extra->tensors[j]->nb[1];
17421810
wrapped->nb[2] = src_extra->tensors[j]->nb[2];
17431811
wrapped->nb[3] = src_extra->tensors[j]->nb[3];
1812+
ensure_rejoined(nullptr, tensor);
17441813
}
17451814
else if (tensor->op == GGML_OP_PERMUTE) {
17461815
auto src = tensor->src[0];

0 commit comments

Comments
 (0)