Skip to content

Commit ea3cab5

Browse files
WIP
1 parent c0358bd commit ea3cab5

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

ggml/src/ggml-backend.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1596,14 +1596,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
1596 1596   const int split_backend_id = split->backend_id;
1597 1597   ggml_backend_t split_backend = sched->backends[split_backend_id];
1598 1598
1599      - bool execute_inputs = false;
     1599 + std::vector<ggml_tensor *> active_inputs;
1600 1600   // copy the input tensors to the split backend
1601 1601   for (int j = 0; j < split->n_inputs; j++) {
1602 1602   ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
1603 1603   struct ggml_tensor * input = split->inputs[j];
1604 1604   struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
1605 1605   if (input_cpy->op != GGML_OP_NONE) {
1606      - execute_inputs = true;
     1606 + active_inputs.push_back(input_cpy);
1607 1607   continue;
1608 1608   }
1609 1609
@@ -1635,12 +1635,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
1635 1635   }
1636 1636   }
1637 1637   }
1638      - if (execute_inputs) {
     1638 + if (!active_inputs.empty()) {
1639 1639   ggml_cgraph graph_inputs = {
1640 1640   /*.size =*/ 0,
1641      - /*.n_nodes =*/ split->n_inputs,
     1641 + /*.n_nodes =*/ int(active_inputs.size()),
1642 1642   /*.n_leafs =*/ 0,
1643      - /*.nodes =*/ split->inputs,
     1643 + /*.nodes =*/ active_inputs.data(),
1644 1644   /*.grads =*/ NULL, // gradients would need visited_hash_set
1645 1645   /*.grad_accs =*/ NULL,
1646 1646   /*.leafs =*/ NULL,

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2656,8 +2656,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
2656 2656   }
2657 2657
2658 2658   #ifndef NDEBUG
2659      - assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
2660      - ggml_backend_buft_is_cuda_split(node->buffer->buft));
     2659 + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
2661 2660   for (int j = 0; j < GGML_MAX_SRC; j++) {
2662 2661   if (node->src[j] != nullptr) {
2663 2662   assert(node->src[j]->buffer);

0 commit comments

Comments
 (0)