@@ -226,10 +226,15 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
226226 if (i == 1 && tensor->op == GGML_OP_ROPE) {
227227 wrapped->src [i] = src_extra->tensors [j];
228228 }
229+
229230 if (i == 3 && tensor->op == GGML_OP_FLASH_ATTN_EXT) {
230231 wrapped->src [i] = src_extra->tensors [j];
231232 }
232233
234+ if (i == 1 && tensor->op == GGML_OP_GET_ROWS) {
235+ wrapped->src [i] = src_extra->tensors [j];
236+ }
237+
233238 if (tensor->op == GGML_OP_RMS_NORM) {
234239 ggml_tensor * check;
235240 if (src_extra->split_tensors == GGML_TP_SPLIT_REDUCE) {
@@ -261,7 +266,7 @@ static void unwrap_tensor(ggml_tensor * tensor, std::set<ggml_tensor *> & tensor
261266 }
262267
263268 if (!wrapped->src [i]) {
264- GGML_LOG_ERROR (" Tensor %s unwrap failure.\n " , tensor->name , src->name );
269+ GGML_ABORT (" Tensor %s unwrap failure.\n " , tensor->name , src->name );
265270 }
266271 }
267272 }
@@ -329,7 +334,7 @@ static bool is_split_compatible(ggml_tensor * tensor) {
329334 case GGML_OP_ROPE:
330335 case GGML_OP_FLASH_ATTN_EXT:
331336 case GGML_OP_CPY:
332- // case GGML_OP_GET_ROWS:
337+ case GGML_OP_GET_ROWS:
333338 return true ;
334339 default :
335340 return false ;
@@ -1625,8 +1630,10 @@ static enum ggml_status ggml_backend_tp_buffer_init_tensor(ggml_backend_buffer_t
16251630 auto src0 = tensor->src [0 ];
16261631 auto src0_extra = (ggml_tensor_parallel_extra *)src0->extra ;
16271632 }
1633+ else if (tensor->op == GGML_OP_GET_ROWS) {
1634+ // nothing to split.
1635+ }
16281636 else if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
1629- int i = 0 ;
16301637 ensure_row_split (tensor->src [0 ]);
16311638 ensure_dim2_split (tensor->src [1 ]);
16321639 ensure_dim2_split (tensor->src [2 ]);
0 commit comments