Skip to content

Commit a02adf4

Browse files
committed
sampling : add assertions for contiguous tensors in async copy functions
1 parent 883a870 commit a02adf4

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/llama-context.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,8 @@ static void copy_tensor_async_ints(
1231 1231
const uint32_t row = it->second;
1232 1232
GGML_ASSERT(row < sampled_size);
1233 1233

1234+
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
1235+
1234 1236
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1235 1237
ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
1236 1238
}
@@ -1253,6 +1255,8 @@ static void copy_tensor_async_floats(
1253 1255
const uint32_t row = it->second;
1254 1256
GGML_ASSERT(row < counts.size());
1255 1257

1258+
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
1259+
1256 1260
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1257 1261
float * row_ptr = dst + (size_t) row * stride;
1258 1262
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
@@ -1279,6 +1283,8 @@ static void copy_tensor_async_candidates(
1279 1283
const uint32_t row = it->second;
1280 1284
GGML_ASSERT(row < counts.size());
1281 1285

1286+
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
1287+
1282 1288
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
1283 1289
llama_token * row_ptr = dst + (size_t) row * stride;
1284 1290
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));

0 commit comments

Comments
 (0)