Skip to content

Commit f79f6f3

Browse files
committed
wip
1 parent b133d07 commit f79f6f3

File tree

1 file changed

+105
-6
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,6 +1122,72 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11221122
}
11231123
};
11241124

1125+
bool has_init = false;
1126+
auto create_reduce_op_tensors = [&] {
1127+
if (has_init) {
1128+
GGML_ABORT("Tensor %s has already been initialized, but is being initialized again.\n", tensor->name);
1129+
}
1130+
has_init = true;
1131+
ggml_backend_tp_finish_init_tensor(tensor);
1132+
1133+
// one of these must be a reduce tensor, the other may be a split tensor, unsplit tensor, or even another reduce tensor.
1134+
auto reduce_tensor = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE ? src0 : src1;
1135+
auto add_tensor = src0_extra->split_tensors == GGML_TP_SPLIT_REDUCE ? src1 : src0;
1136+
auto add_extra = (ggml_tensor_parallel_extra *)add_tensor->extra;
1137+
auto reduce_extra = (ggml_tensor_parallel_extra *)reduce_tensor->extra;
1138+
1139+
if (add_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1140+
// double reduce add can simply be added without any views.
1141+
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
1142+
auto wrapped = extra->tensors[j];
1143+
auto reduce_op = ggml_backend_tp_clone_tensor(wrapped);
1144+
extra->reduce_op_tensors[j] = reduce_op;
1145+
reduce_op->buffer = wrapped->buffer;
1146+
reduce_op->view_src = wrapped;
1147+
reduce_op->view_offs = 0;
1148+
reduce_op->data = wrapped->data;
1149+
1150+
reduce_op->src[0] = reduce_extra->has_rejoin ? reduce_extra->rejoined_tensor_views[j][j] : reduce_extra->tensors[j];
1151+
reduce_op->src[1] = add_extra->has_rejoin ? add_extra->rejoined_tensor_views[j][j] : add_extra->tensors[j];
1152+
}
1153+
}
1154+
else {
1155+
auto splits = get_col_splits(tensor);
1156+
size_t col_offset = 0;
1157+
1158+
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
1159+
auto wrapped = extra->tensors[j];
1160+
1161+
// create a col split view of the destination that can be used for reduction
1162+
auto reduce_op = ggml_backend_tp_clone_tensor(wrapped);
1163+
extra->reduce_op_tensors[j] = reduce_op;
1164+
reduce_op->buffer = wrapped->buffer;
1165+
reduce_op->view_src = wrapped;
1166+
reduce_op->view_offs = col_offset * wrapped->nb[0];
1167+
reduce_op->data = wrapped->data + reduce_op->view_offs;
1168+
reduce_op->ne[0] = splits.split[j];
1169+
1170+
// the reduce was rejoined, and the
1171+
auto reduce = reduce_extra->tensors[j];
1172+
if (reduce_extra->has_rejoin) {
1173+
reduce = reduce_extra->rejoined_tensor_views[j][j];
1174+
}
1175+
1176+
// create a col split view of the reduced tensor
1177+
ensure_reduce_split_views(reduce_tensor);
1178+
1179+
auto reduce_op_src_view = reduce_extra->reduce_split_views[j];
1180+
reduce_op->src[0] = reduce_op_src_view;
1181+
1182+
auto add = add_extra->split_tensors == GGML_TP_SPLIT_NONE ? add_extra->converted_tensors[j] : add_extra->tensors[j];
1183+
reduce_op->src[1] = add;
1184+
1185+
col_offset += splits.split[j];
1186+
}
1187+
1188+
}
1189+
};
1190+
11251191
auto no_reduce = [&](ggml_tensor *src, ggml_tensor_parallel_extra *src_extra) {
11261192
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
11271193
ensure_rejoined(tensor, src);
@@ -1151,7 +1217,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11511217
wrapped->src[src_index] = src_extra->tensors[j];
11521218
}
11531219
}
1154-
else if (split == GGML_TP_SPLIT_ROWS || GGML_TP_SPLIT_COLUMNS || GGML_TP_SPLIT_DIM2) {
1220+
else if (split == GGML_TP_SPLIT_ROWS || split == GGML_TP_SPLIT_COLUMNS || split == GGML_TP_SPLIT_DIM2) {
11551221
if (src_extra->split_tensors == GGML_TP_SPLIT_NONE) {
11561222
wrapped->src[src_index] = src_extra->converted_tensors[j];
11571223
}
@@ -1162,12 +1228,38 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
11621228
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src%d is split as %d but requested to be split as %d.\n", tensor->name, ggml_op_name(tensor->op), src_index, src_extra->split_tensors, split);
11631229
}
11641230
}
1231+
else if (split == GGML_TP_SPLIT_REDUCE) {
1232+
if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
1233+
wrapped->src[src_index] = src_extra->tensors[j];
1234+
}
1235+
else if (src_extra->split_tensors) {
1236+
wrapped->src[src_index] = src_extra->reduce_split_views[j];
1237+
}
1238+
else {
1239+
wrapped->src[src_index] = src_extra->tensors[j];
1240+
}
1241+
}
11651242
else {
11661243
GGML_ABORT("NYI\n");
11671244
}
11681245
}
11691246
};
11701247

1248+
// Post-wiring sanity check: for every src slot present on the original
// tensor, the per-device wrapped tensor (device 0 is representative) must
// have had a corresponding src assigned (by set_src_tensor or the reduce-op
// wiring); aborts if any populated slot was left unset.
auto check_srcs = [&]() {
    if (src0 && !extra->tensors[0]->src[0]) {
        GGML_ABORT("Tensor %s failed to create src0 tensors for tensor parallelism.\n", tensor->name);
    }
    if (src1 && !extra->tensors[0]->src[1]) {
        GGML_ABORT("Tensor %s failed to create src1 tensors for tensor parallelism.\n", tensor->name);
    }
    if (src2 && !extra->tensors[0]->src[2]) {
        GGML_ABORT("Tensor %s failed to create src2 tensors for tensor parallelism.\n", tensor->name);
    }
    if (src3 && !extra->tensors[0]->src[3]) {
        GGML_ABORT("Tensor %s failed to create src3 tensors for tensor parallelism.\n", tensor->name);
    }
};
1262+
11711263
switch (tensor->op) {
11721264
case GGML_OP_ROPE: {
11731265
no_split_view(src0, src0_extra);
@@ -1229,6 +1321,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12291321
set_src_tensor(2, GGML_TP_SPLIT_COLUMNS);
12301322
set_src_tensor(3, GGML_TP_SPLIT_DIM2);
12311323
}
1324+
check_srcs();
12321325
break;
12331326
}
12341327

@@ -1255,8 +1348,7 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12551348

12561349
if (src0_split_tensors == GGML_TP_SPLIT_REDUCE && src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
12571350
create_reduce_tensors();
1258-
set_src_tensor(0, GGML_TP_SPLIT_REDUCE);
1259-
set_src_tensor(1, GGML_TP_SPLIT_REDUCE);
1351+
create_reduce_op_tensors();
12601352
}
12611353
else if (src0_split_tensors & !src1_split_tensors) {
12621354
auto split_tensors = src0_split_tensors ? src0_split_tensors : src1_split_tensors;
@@ -1276,9 +1368,14 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
12761368
set_src_tensor(1, GGML_TP_SPLIT_ROWS);
12771369
}
12781370
else if (src0_split_tensors == GGML_TP_SPLIT_REDUCE) {
1371+
ensure_reduce_split_views(src1);
1372+
create_reduce_tensors();
1373+
create_reduce_op_tensors();
1374+
}
1375+
else if (src1_split_tensors == GGML_TP_SPLIT_REDUCE) {
1376+
ensure_reduce_split_views(src0);
12791377
create_reduce_tensors();
1280-
set_src_tensor(0, GGML_TP_SPLIT_REDUCE);
1281-
set_src_tensor(1, GGML_TP_SPLIT_REDUCE);
1378+
create_reduce_op_tensors();
12821379
}
12831380
else {
12841381
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split but src1 is not.\n", tensor->name, ggml_op_name(tensor->op));
@@ -1497,7 +1594,9 @@ static void do_init(ggml_tensor * tensor, ggml_tensor_parallel_extra * extra) {
14971594
GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism.\n", tensor->name, ggml_op_name(tensor->op));
14981595
}
14991596

1500-
ggml_backend_tp_finish_init_tensor(tensor);
1597+
if (!has_init) {
1598+
ggml_backend_tp_finish_init_tensor(tensor);
1599+
}
15011600
}
15021601

15031602
static enum ggml_status ggml_backend_tp_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {

0 commit comments

Comments
 (0)