Skip to content

Commit 48cf9e4

Browse files
committed
Add heuristics to execute the CUB branch only when it improves performance
Heuristics were determined on the following hardware: RTX 4000 SFF Ada, RTX 6000 Ada, RTX PRO 6000 Blackwell Max-Q, and RTX PRO 4500 Blackwell.
1 parent 7c7413e commit 48cf9e4

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

ggml/src/ggml-cuda/mean.cu

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,19 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
2525

2626
// Special case for reducing vectors
2727
#ifdef USE_CUB
28-
if (nrows == 1) {
28+
cudaStreamCaptureStatus iscapturing;
29+
CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
30+
if ((nrows == 1) &&
31+
// CUDA_GRAPHS_DISABLED
32+
((ncols > 65536) &&
33+
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
34+
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
35+
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
36+
// CUDA_GRAPHS ENABLED
37+
((ncols > 32768) &&
38+
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
39+
ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
40+
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
2941
// Single row - use device-wide reduction
3042
size_t tmp_size = 0;
3143
ggml_cuda_pool & pool = ctx.pool();

0 commit comments

Comments
 (0)