Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1957,8 +1957,15 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct

size_t src1_stride_size = sizeof(cuda_t);

dim3 block_dims(ne13, ne12);
k_compute_batched_ptrs<<<1, block_dims, 0, main_stream>>>(
const int threads_x = 32;
const int threads_y = 32;
dim3 block_dims(threads_x, threads_y);

dim3 grid_dims(
(ne13 + threads_x - 1) / threads_x,
(ne12 + threads_y - 1) / threads_y
);
k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
src0_ptr, src1_ptr, dst_t,
ptrs_src.get(), ptrs_dst.get(),
ne12, ne13,
Expand Down
4 changes: 4 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6562,6 +6562,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 1024, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 1024, {3, 2}, {1, 1}));

// test cases with large batch size
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {1024, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {4096, 1}, {1, 1}));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a batch size of 1024 be enough to test this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1024 happens to be the maximum number of threads, so it won’t trigger any issues. I’ve now changed the batch size to 1536, which is smaller than before.

}
}
for (ggml_type type_a : other_types) {
Expand Down
Loading