Skip to content

Commit 4f0aebb

Browse files
JohannesGaessler authored and Nexesenex committed
CUDA: fix MMQ for non-contiguous src0, add tests (ggml-org#10021)
* CUDA: fix MMQ for non-contiguous src0, add tests
* revise test code
1 parent f1737de commit 4f0aebb

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,20 +1538,13 @@ static void ggml_cuda_op_mul_mat(
15381538
const size_t nbytes_data = ggml_nbytes(src0);
15391539
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
15401540
dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
1541-
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
1542-
}
1543-
1544-
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
1545-
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
1546-
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
1547-
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
15481541
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
15491542
}
15501543

1551-
// If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
1544+
// If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
15521545
if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
1553-
const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
1554-
const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
1546+
const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
1547+
const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
15551548
CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
15561549
}
15571550

ggml/src/ggml.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4132,7 +4132,7 @@ GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
41324132

41334133
GGML_CALL size_t ggml_nbytes(const struct ggml_tensor * tensor) {
41344134
size_t nbytes;
4135-
size_t blck_size = ggml_blck_size(tensor->type);
4135+
const size_t blck_size = ggml_blck_size(tensor->type);
41364136
if (blck_size == 1) {
41374137
nbytes = ggml_type_size(tensor->type);
41384138
for (int i = 0; i < GGML_MAX_DIMS; ++i) {

0 commit comments

Comments (0)