Skip to content

Commit cbe4ada

Browse files
committed
weight fix
1 parent 9a2d9a5 commit cbe4ada

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

ggml/src/ggml-tp/ggml-tp.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
10651065
dims = dims ? dims : tensor;
10661066
extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
10671067
auto splits = get_col_splits(dims);
1068+
auto offset_splits = get_dim_splits(tensor->view_offs);
10681069
for (size_t j = 0; j < ggml_parallel_devices.size(); j++) {
10691070
auto wrapped = prepare_wrapped(tensor, dims, offset_aware);
10701071
extra->tensors[j] = wrapped;
@@ -1075,6 +1076,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
10751076
wrapped->nb[1] = wrapped->nb[1] / dims->ne[0] * splits.split[j];
10761077
wrapped->nb[2] = wrapped->nb[2] / dims->ne[0] * splits.split[j];
10771078
wrapped->nb[3] = wrapped->nb[3] / dims->ne[0] * splits.split[j];
1079+
1080+
if (offset_aware) {
1081+
wrapped->view_offs = offset_splits.split[j];
1082+
}
10781083
}
10791084
};
10801085

@@ -1720,11 +1725,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17201725
else {
17211726
// when a weight matrix is multiplied by a column split tensor (prior to ROPE), the weight can be massaged to a column split.
17221727
// this results in a reduce split.
1723-
ensure_row_split(src0);
1724-
ensure_rejoined(tensor, src1);
1725-
create_column_split_tensors();
1726-
set_src_tensor(0, GGML_TP_SPLIT_ROWS);
1727-
set_src_tensor(1, GGML_TP_SPLIT_NONE);
1728+
ensure_weight_column_split(src0);
1729+
create_reduce_tensors();
1730+
set_src_tensor(0, GGML_TP_SPLIT_COLUMNS);
1731+
set_src_tensor(1, GGML_TP_SPLIT_COLUMNS);
17281732
}
17291733
}
17301734
else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
@@ -2670,7 +2674,7 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
26702674
if (src0->ne[1] >= 4096)
26712675
return true;
26722676
if (src0->ne[1] * src0->ne[2] >= 4096) {
2673-
if (src0->ne[1] >= 2048)
2677+
if (src0->ne[1] >= 1024)
26742678
return true;
26752679
return false;
26762680
}

0 commit comments

Comments
 (0)