@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
     std::vector<int32_t> ids;
     std::vector<ggml_bitset_t> used_ids;

-    for (int i = 0; i < sched->n_splits; i++) {
-        struct ggml_backend_sched_split * split = &splits[i];
+    for (int split_id = 0; split_id < sched->n_splits; split_id++) {
+        struct ggml_backend_sched_split * split = &splits[split_id];
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];

         // copy the input tensors to the split backend
-        for (int j = 0; j < split->n_inputs; j++) {
-            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
-            struct ggml_tensor * input = split->inputs[j];
+        for (int input_id = 0; input_id < split->n_inputs; input_id++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
+            struct ggml_tensor * input = split->inputs[input_id];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);

             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@@ -1398,17 +1398,30 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s

                     // get the ids
                     ggml_tensor * ids_tensor = node->src[2];
+                    ggml_backend_t ids_backend = split_backend;
+
+                    // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
+                    // in that case, we use the original ids tensor
+                    for (int i = input_id + 1; i < split->n_inputs; i++) {
+                        if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
+                            ids_tensor = split->inputs[i];
+                            ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
+                            break;
+                        }
+                    }
+
                     if (ids_tensor != prev_ids_tensor) {
                         ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
-                        ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
-                        ggml_backend_synchronize(split_backend);
+                        ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
+                        ggml_backend_synchronize(ids_backend);

                         // find the used experts
                         used_ids.clear();
                         used_ids.resize(ggml_bitset_size(n_expert));
                         for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
                             for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
                                 int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
+                                GGML_ASSERT(id >= 0 && id < n_expert);
                                 ggml_bitset_set(used_ids.data(), id);
                             }
                         }
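
For reference, a minimal, self-contained sketch of the used-expert bitmask pass in the second hunk, with a plain `std::vector<uint32_t>` standing in for `ggml_bitset_t` and a fixed row-major `int32_t` array standing in for the ids tensor. The sizes `n_expert`, `n_used`, and `n_tokens` are illustrative values, not part of the ggml API.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_expert = 8; // total experts in the MoE weight tensor (illustrative)
    const int n_used   = 2; // experts selected per token, i.e. ids_tensor->ne[0]
    const int n_tokens = 3; // tokens in the batch, i.e. ids_tensor->ne[1]

    // row-major stand-in for the ids tensor: ids[token][slot] = selected expert index
    const int32_t ids[n_tokens][n_used] = {
        {1, 4},
        {4, 6},
        {1, 6},
    };

    // one bit per expert, rounded up to whole 32-bit words (the role of ggml_bitset_t)
    std::vector<uint32_t> used_ids((n_expert + 31) / 32, 0);

    for (int i1 = 0; i1 < n_tokens; i1++) {
        for (int i0 = 0; i0 < n_used; i0++) {
            const int32_t id = ids[i1][i0];
            assert(id >= 0 && id < n_expert);     // mirrors the GGML_ASSERT added above
            used_ids[id / 32] |= 1u << (id % 32); // the role of ggml_bitset_set
        }
    }

    // only experts whose bit is set need their weights copied to the split backend
    for (int id = 0; id < n_expert; id++) {
        if (used_ids[id / 32] & (1u << (id % 32))) {
            printf("expert %d is used\n", id);
        }
    }
    return 0;
}
```

In the scheduler itself the two index strides come from `ids_tensor->nb[0]` and `ids_tensor->nb[1]` divided by `sizeof(int32_t)`, so the same loop also handles non-contiguous ids layouts; the sketch uses a fixed row length instead.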