Skip to content

Commit 8e04242

Browse files
committed
Ensure perf gains also for small ncols and large nrows
Alternative to this, one could have also made the number of unrollings template-able, but that would require compiling the kernel multiple times, increasing binary size unnecessarily
1 parent 80de672 commit 8e04242

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ggml/src/ggml-cuda/mean.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1919
constexpr dim3 block_dims(512, 1, 1);
2020
reduce_rows_f32</*norm=*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
2121
} else {
22-
constexpr dim3 block_dims(128, 1, 1);
22+
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
2323
reduce_rows_f32</*norm=*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
2424
}
2525
}

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int
77
const dim3 block_dims(512, 1, 1);
88
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
99
} else {
10-
const dim3 block_dims(128, 1, 1);
10+
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
1111
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
1212
}
1313
}
@@ -33,7 +33,7 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
3333
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
3434
} else {
3535
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
36-
const dim3 block_dims(128, 1, 1);
36+
const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
3737
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
3838
}
3939
}

0 commit comments

Comments
 (0)