Skip to content

Commit 80de672

Browse files
committed
Add heuristic to toggle 128/512 threads based on sm count
Break-even point was the minimum of the following multiples.

| GPU Model | Nrow SM Count Multiple |
| ----------- | ----------- |
| RTX 4000 SFF ADA | 2.0x |
| RTX 6000 ADA | 2.5x |
| RTX PRO 6000 Blackwell Max-Q | 3.04x |
| RTX PRO 4500 Blackwell | 3.15x |
1 parent 9070af8 commit 80de672

File tree

6 files changed

+32
-13
lines changed

6 files changed

+32
-13
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,7 @@ struct ggml_cuda_graph {
844844
struct ggml_backend_cuda_context {
845845
int device;
846846
std::string name;
847+
int sm_count;
847848
cudaEvent_t copy_event = nullptr;
848849

849850
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
@@ -853,7 +854,8 @@ struct ggml_backend_cuda_context {
853854

854855
explicit ggml_backend_cuda_context(int device) :
855856
device(device),
856-
name(GGML_CUDA_NAME + std::to_string(device)) {
857+
name(GGML_CUDA_NAME + std::to_string(device)),
858+
sm_count(ggml_cuda_info().devices[device].nsm) {
857859
}
858860

859861
~ggml_backend_cuda_context();

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
131131
CUDA_CHECK(cudaGetLastError());
132132

133133
// Combine results from individual blocks:
134-
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
134+
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream, ctx.sm_count);
135135
}
136136

137137
void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

ggml/src/ggml-cuda/mean.cu

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1414
const int64_t ncols = src0->ne[0];
1515
const int64_t nrows = ggml_nrows(src0);
1616

17-
constexpr dim3 block_dims(512, 1, 1);
1817
const dim3 block_nums(nrows, 1, 1);
19-
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
18+
if ((nrows / ctx.sm_count)< 2){
19+
constexpr dim3 block_dims(512, 1, 1);
20+
reduce_rows_f32</*norm=*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
21+
} else {
22+
constexpr dim3 block_dims(128, 1, 1);
23+
reduce_rows_f32</*norm=*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
24+
}
2025
}

ggml/src/ggml-cuda/sum.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using namespace cub;
1212

1313
#include <cstdint>
1414

15-
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
15+
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream, int & n_sm) {
1616
#ifdef USE_CUB
1717
size_t tmp_size = 0;
1818
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
@@ -41,5 +41,5 @@ void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
4141
ggml_cuda_pool & pool = ctx.pool();
4242
cudaStream_t stream = ctx.stream();
4343

44-
sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
44+
sum_f32_cuda(pool, src0_d, dst_d, ne, stream, ctx.sm_count);
4545
}

ggml/src/ggml-cuda/sum.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "common.cuh"
22

3-
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
3+
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream, int & n_sm);
44

55
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/sumrows.cu

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
#include "sumrows.cuh"
22
#include "reduce_rows.cuh"
33

4-
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
5-
constexpr dim3 block_dims(512, 1, 1);
4+
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream, int & n_sm) {
65
const dim3 block_nums(nrows, 1, 1);
7-
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
6+
if ((nrows / n_sm)< 2){
7+
const dim3 block_dims(512, 1, 1);
8+
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
9+
} else {
10+
const dim3 block_dims(128, 1, 1);
11+
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
12+
}
813
}
914

1015
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -20,8 +25,15 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
2025
const int64_t ncols = src0->ne[0];
2126
const int64_t nrows = ggml_nrows(src0);
2227

23-
constexpr dim3 block_dims(512, 1, 1);
2428
const dim3 block_nums(nrows, 1, 1);
25-
26-
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
29+
30+
if ((nrows / ctx.sm_count)< 2){
31+
// Increase num threads to 512 for small nrows to better hide the latency
32+
const dim3 block_dims(512, 1, 1);
33+
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
34+
} else {
35+
// Enough active SMs to hide latency, use smaller blocks to allow better scheduling
36+
const dim3 block_dims(128, 1, 1);
37+
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
38+
}
2739
}

0 commit comments

Comments
 (0)