@@ -1065,6 +1065,7 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
10651065 dims = dims ? dims : tensor;
10661066 extra->split_tensors = GGML_TP_SPLIT_COLUMNS;
10671067 auto splits = get_col_splits (dims);
1068+ auto offset_splits = get_dim_splits (tensor->view_offs );
10681069 for (size_t j = 0 ; j < ggml_parallel_devices.size (); j++) {
10691070 auto wrapped = prepare_wrapped (tensor, dims, offset_aware);
10701071 extra->tensors [j] = wrapped;
@@ -1075,6 +1076,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
10751076 wrapped->nb [1 ] = wrapped->nb [1 ] / dims->ne [0 ] * splits.split [j];
10761077 wrapped->nb [2 ] = wrapped->nb [2 ] / dims->ne [0 ] * splits.split [j];
10771078 wrapped->nb [3 ] = wrapped->nb [3 ] / dims->ne [0 ] * splits.split [j];
1079+
1080+ if (offset_aware) {
1081+ wrapped->view_offs = offset_splits.split [j];
1082+ }
10781083 }
10791084 };
10801085
@@ -1720,11 +1725,10 @@ static void do_init(size_t node_index, ggml_tensor * tensor, ggml_tensor_paralle
17201725 else {
17211726 // a weight matrix is multiplied by a column split tensor (prior to ROPE), it can be massaged to a column split.
17221727 // this results in a reduce split.
1723- ensure_row_split (src0);
1724- ensure_rejoined (tensor, src1);
1725- create_column_split_tensors ();
1726- set_src_tensor (0 , GGML_TP_SPLIT_ROWS);
1727- set_src_tensor (1 , GGML_TP_SPLIT_NONE);
1728+ ensure_weight_column_split (src0);
1729+ create_reduce_tensors ();
1730+ set_src_tensor (0 , GGML_TP_SPLIT_COLUMNS);
1731+ set_src_tensor (1 , GGML_TP_SPLIT_COLUMNS);
17281732 }
17291733 }
17301734 else if (src0_split_tensors == GGML_TP_SPLIT_COLUMNS && src1_split_tensors == GGML_TP_SPLIT_COLUMNS) {
@@ -2670,7 +2674,7 @@ static bool ggml_backend_tp_device_supports_op(ggml_backend_dev_t dev, const str
26702674 if (src0->ne [1 ] >= 4096 )
26712675 return true ;
26722676 if (src0->ne [1 ] * src0->ne [2 ] >= 4096 ) {
2673- if (src0->ne [1 ] >= 2048 )
2677+ if (src0->ne [1 ] >= 1024 )
26742678 return true ;
26752679 return false ;
26762680 }
0 commit comments