Commit ac56fb0

fix view snafu
1 parent 3fc764b

File tree: 1 file changed, +13 -6 lines

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 13 additions & 6 deletions
@@ -957,7 +957,7 @@ static ggml_tensor* ggml_backend_tp_node_compute_split(int device_index, ggml_te
         return reduce_op_tensor;
     }
 
-static bool immediate_compute = true;
+static bool immediate_compute = false;
 static void ggml_backend_tp_buffer_compute_graph(ggml_cgraph * cgraph, std::function<bool(int, std::set<ggml_tensor*>)> gather_pending, std::function<bool(int, ggml_tensor *, ggml_tensor_parallel_extra *)> compute, std::function<void(int, std::set<ggml_tensor*>)> flush_compute) {
     std::set<ggml_tensor*> pending_gathers;
     for (int node_index = 0; node_index < cgraph->n_nodes; node_index++) {
@@ -1360,14 +1360,14 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
         case GGML_OP_VIEW:
         case GGML_OP_FLASH_ATTN_EXT:
         case GGML_OP_RESHAPE:
-        case GGML_OP_PERMUTE:
+        // case GGML_OP_PERMUTE:
         case GGML_OP_MUL:
         case GGML_OP_MUL_MAT:
            force_rejoin = false;
            break;
    }
 
-    if (force_rejoin) {
+    if (false) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            auto src = tensor->src[i];
            if (!src) {
@@ -1473,10 +1473,17 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
            GGML_ABORT("Tensor %s has unsupported op %s for tensor parallelism, src0 is split as view but not evenly divisible by the rope head count.\n", tensor->name, ggml_op_name(tensor->op));
        }
 
-        // similar to ROPE above, the input must be row split and becomes column split.
-        src0_split_tensors = GGML_TP_SPLIT_ROWS;
-        create_row_split_tensors_for(src0, src0_extra);
+        // similar to ROPE above, the input must be dim2 split and becomes row split.
+        src0_split_tensors = GGML_TP_SPLIT_DIM2;
+        create_dim2_split_tensors_for(src0, src0_extra);
        set_src_tensor_for(src0, src0_extra, 0, GGML_TP_SPLIT_ROWS);
+        if (src0->op == GGML_OP_PERMUTE) {
+            auto splits = get_dim_splits(src0->nb[1]);
+            for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
+                auto wrapped = src0_extra->tensors[j];
+                wrapped->nb[1] = splits.split[j];
+            }
+        }
        ggml_backend_tp_finish_init_tensor(src0);
    }

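Note on the PERMUTE stride fix in the last hunk: when a permuted view's source tensor is split across devices, each per-device wrapped tensor initially carries the nb[1] stride of the full, unsplit tensor, so the commit overwrites it with the device-local share returned by get_dim_splits. The snippet below is a minimal, self-contained sketch of that idea only; DimSplits, the even-split distribution, and every name except nb[1] are assumptions made for illustration, not ggml-tp's actual API.

// Sketch: recompute a per-device nb[1] stride for a split tensor.
// DimSplits and get_dim_splits_sketch() are hypothetical stand-ins.
#include <cstdio>
#include <cstddef>
#include <vector>

struct DimSplits {
    std::vector<size_t> split;   // per-device share of the original stride
};

// Divide a byte stride evenly across devices, giving the remainder to the last one
// (assumed distribution; the real get_dim_splits may split differently).
static DimSplits get_dim_splits_sketch(size_t nb1, size_t n_devices) {
    DimSplits s;
    size_t base = nb1 / n_devices;
    for (size_t j = 0; j < n_devices; j++) {
        s.split.push_back(j + 1 == n_devices ? nb1 - base * j : base);
    }
    return s;
}

int main() {
    const size_t n_devices = 2;
    const size_t full_nb1  = 4096;   // stride of the unsplit tensor, in bytes

    // Per-device wrapped tensors start out with the full stride copied from the parent.
    std::vector<size_t> wrapped_nb1(n_devices, full_nb1);

    // Mirror of the loop added in the commit: overwrite each wrapped tensor's nb[1]
    // with its device-local share so the permuted view stays consistent after the split.
    DimSplits splits = get_dim_splits_sketch(full_nb1, n_devices);
    for (size_t j = 0; j < n_devices; j++) {
        wrapped_nb1[j] = splits.split[j];
        std::printf("device %zu: nb[1] = %zu\n", j, wrapped_nb1[j]);
    }
    return 0;
}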