
Conversation

@ORippler
Contributor

@ORippler ORippler commented Aug 6, 2025

Investigation of Gemma3n perf on NVGPUs identified the reduce_rows_f32 kernel as a major performance bottleneck. Profiling revealed the kernel to be severely latency-limited in the regime exercised by Gemma3n (nrows ~10, ncols in [2048, 8192]).

This PR addresses this issue, hiding the latency by a combination of:

  1. Manual loop unrolling, getting the compiler to request all unrolled data points at once instead of fetching data sequentially (unfortunately, #pragma unroll alone did not do the trick). A sketch of this idea follows the list below.
  2. Increasing the number of threads processing a row, where 512 threads are used for the low-parallelization regime (i.e. processing only a single row). This gives the SM 16 full warps to cycle through, further pipelining data fetching.
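
For illustration, here is a minimal, hedged sketch of the unrolling idea from point 1 (the function name, the `UNROLL` parameter, and the loop structure are assumptions, not the PR's actual kernel; the real code unrolls the loads manually, while this sketch approximates the effect with independent accumulators):

```cuda
// Hedged sketch only: illustrates the manual-unrolling idea, not the PR's exact kernel.
// UNROLL independent accumulators let the compiler issue all UNROLL loads of an
// iteration up front instead of serializing load -> add -> load -> add.
template <int UNROLL>
static __device__ float row_sum_unrolled(const float * __restrict__ row, const int ncols) {
    float sum_temp[UNROLL] = {0.0f};
    const int stride = blockDim.x;

    int col = threadIdx.x;
    for (; col + (UNROLL - 1) * stride < ncols; col += UNROLL * stride) {
#pragma unroll
        for (int i = 0; i < UNROLL; ++i) {
            sum_temp[i] += row[col + i * stride];   // UNROLL independent loads in flight
        }
    }

    float sum = 0.0f;
    for (; col < ncols; col += stride) {            // tail: columns not covered by the unrolled loop
        sum += row[col];
    }
#pragma unroll
    for (int i = 0; i < UNROLL; ++i) {              // fold the independent partial sums
        sum += sum_temp[i];
    }
    return sum;                                     // per-thread partial sum of the row
}
```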

Since perf regressions were identified in the high-parallelization regime (nrows >= 2x SM count), we use:

  • 128 threads for medium-to-large columns, effectively letting each SM process a single row (an SM can execute 4 warps x 32 threads = 128 threads concurrently).
  • As perf regressions were still observed for small columns (< 1024 cols, i.e. less than one unrolled loop iteration of a threadblock of size 128 with 8 unrolls), the thread count was further reduced to 32 threads for small columns. An alternative would have been to template the number of unrolls on the column size; however, this would increase binary size due to the required compilation of multiple kernel variants, and was thus not pursued further. A host-side sketch of the resulting selection logic follows the threshold table below.

The high/low parallelization threshold was empirically determined:

| GPU Model                    | Nrow SM Count Multiple, where 128 beats 512 threads |
| ---------------------------- | ---------------------------------------------------: |
| RTX 4000 SFF ADA             | 2.0x  |
| RTX 6000 ADA                 | 2.5x  |
| RTX PRO 6000 Blackwell Max-Q | 3.04x |
| RTX PRO 4500 Blackwell       | 3.15x |
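
A hedged host-side sketch of the block-size selection described above (the function name and the exact thresholds are illustrative stand-ins; the merged code derives its SM-count multiple from the measurements in the table):

```cuda
// Hedged sketch of the launch heuristic (illustrative names and thresholds only).
static dim3 reduce_rows_block_dims(const int64_t nrows, const int64_t ncols, const int sm_count) {
    // Low-parallelization regime (few rows): 512 threads give the SM 16 warps
    // to cycle through while memory requests are in flight.
    if (nrows < 2 * sm_count) {   // 2x used as a stand-in; break-even was 2.0x-3.15x depending on GPU
        return dim3(512, 1, 1);
    }
    // High-parallelization regime: roughly one threadblock per row.
    if (ncols < 1024) {           // less than one unrolled iteration of 128 threads x 8 unrolls
        return dim3(32, 1, 1);    // a single warp avoids idle threads on short rows
    }
    return dim3(128, 1, 1);       // 4 warps, matching what one SM executes concurrently
}
```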

In total, up to ~25x perf improvement was observed at the kernel level.
[Figure: speedup_comparison_multiple]
Moreover, no regression was observed in any of the investigated combinations.
[Figure: speedup_comparison_fractional]

As a consequence of this general kernel optimization, Gemma3n achieves a ~10% perf increase, going from ~130 to ~145 tok/s on an RTX PRO 6000 Blackwell Max-Q at batch size 1.

Naive:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        147.27 ± 0.82 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        130.75 ± 0.28 |

Optimized:

  Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes
| model                          |       size |     params | backend    | ngl | n_batch |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------: | --------------: | -------------------: |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           pp100 |        168.41 ± 0.37 |
| gemma3n E2B Q8_0               |   4.45 GiB |     4.46 B | CUDA       |  99 |       1 |           tg100 |        146.68 ± 0.68 |

Side note: Similar tendencies were observed for rms_norm_f32, and we intend to optimize said kernel in a separate PR.

@github-actions github-actions bot added the labels testing (Everything test related), Nvidia GPU (Issues specific to Nvidia GPUs), and ggml (changes relating to the ggml tensor library for machine learning) Aug 6, 2025
Factor out `reduce_rows_f32` from common.cuh

This increases iteration cycle speed by not having to recompile
every kernel all the time.

Further optimizations to `reduce_rows_f32`

1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations (see the sketch below)
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflect bigger threadblock
4. Improve default block_dims, increase support for more block_dims
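
For context, a hedged sketch of the 2-step summation from point 1 (this is the generic warp-shuffle plus shared-memory pattern; names and details are assumptions, not necessarily the commit's exact code):

```cuda
// Hedged sketch: reduce the per-thread partial sums of a threadblock in two steps.
// Step 1: each warp reduces its 32 partial sums with shuffles; lane 0 writes the
//         warp result to shared memory.
// Step 2: the first warp reduces the per-warp results from shared memory.
static __device__ float block_reduce_sum(float sum) {
    __shared__ float s_warp_sums[32];                 // enough for up to 1024 threads
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;

    for (int offset = 16; offset > 0; offset >>= 1) { // step 1: intra-warp reduction
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (lane == 0) {
        s_warp_sums[warp] = sum;
    }
    __syncthreads();

    const int n_warps = blockDim.x / 32;              // step 2: reduce the warp results
    sum = (threadIdx.x < n_warps) ? s_warp_sums[lane] : 0.0f;
    if (warp == 0) {
        for (int offset = 16; offset > 0; offset >>= 1) {
            sum += __shfl_down_sync(0xffffffff, sum, offset);
        }
    }
    return sum;                                       // final sum is valid in thread 0
}
```
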
Add heuristic to toggle 128/512 threads based on sm count

Break-even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| ----------------------------- | ---------------------- |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell        | 3.15x                  |

Ensure perf gains also for small ncols and large nrows

As an alternative to this, one could have also made the number of unrolls
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily.

@ORippler ORippler force-pushed the osimons/optimize_reduce_rows_f32 branch from c6ed8cc to 9296d1f on August 7, 2025 07:46
@ORippler
Contributor Author

ORippler commented Aug 7, 2025

Rebased on current master, resolving conflicts along the way. Reran E2E perf tests for gemma3n, and we continue to see perf gains. Nice to see some other optimizations for tg were made in master 😃

Naive:
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes

| model            |     size |  params | backend | ngl | n_batch |  test |           t/s |
| ---------------- | -------: | ------: | ------- | --: | ------: | ----: | ------------: |
| gemma3n E2B Q8_0 | 4.45 GiB |  4.46 B | CUDA    |  99 |       1 | pp100 | 146.89 ± 0.12 |
| gemma3n E2B Q8_0 | 4.45 GiB |  4.46 B | CUDA    |  99 |       1 | tg100 | 145.86 ± 0.13 |

Optimized:
Device 0: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition, compute capability 12.0, VMM: yes

| model            |     size |  params | backend | ngl | n_batch |  test |           t/s |
| ---------------- | -------: | ------: | ------- | --: | ------: | ----: | ------------: |
| gemma3n E2B Q8_0 | 4.45 GiB |  4.46 B | CUDA    |  99 |       1 | pp100 | 167.47 ± 0.29 |
| gemma3n E2B Q8_0 | 4.45 GiB |  4.46 B | CUDA    |  99 |       1 | tg100 | 167.28 ± 0.35 |

Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side effect.

@JohannesGaessler
Collaborator

Thank you for answering my questions (even though I could have gotten the answers by reading the PR description more carefully). If you test using CUB for GGML_MEAN this PR would essentially be good to merge from my side.

@IMbackK
Collaborator

IMbackK commented Aug 7, 2025

Quick test shows this PR is also broadly performance-positive on CDNA and performance-neutral on RDNA2.

Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows == 1.

Add heuristics to execute CUB branch only when it brings perf

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell

Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled per default on NVGPUs.

@ORippler
Contributor Author

> Thank you for answering my questions (even though I could have gotten the answers by reading the PR description more carefully). If you test using CUB for GGML_MEAN this PR would essentially be good to merge from my side.

@JohannesGaessler As requested, I put up a naive implementation that uses CUB for GGML_OP_MEAN. The implementation uses CUB to compute the device-wide sum, and another kernel to divide the sum by ncols (CUB does not offer a device-wide mean operation).
Benchmarks show that for small ncols the CUB-based implementation is slower than reduce_rows_f32, indicating it is worse at hiding the latency of data access. However, using more than one threadblock for a single row allows it to scale much better for high ncols, and as a consequence it outperforms reduce_rows_f32 there. Because the CUB-based implementation has twice the CPU-side kernel-launch overhead of reduce_rows_f32, it starts to outperform reduce_rows_f32 only at higher ncols when CUDA Graphs are disabled.

I reflected the above insights by branching the execution in ggml_cuda_op_mean accordingly. I did not implement or benchmark a CUB-based implementation for nrows > 1, but expect it to be comparable to reduce_rows_f32 for most settings, as rows are effectively parallelized across threadblocks in the current reduce_rows_f32 implementation. I also did not investigate writing a single kernel that uses more granular CUB primitives, as I presume we want to preserve a hipify-able kernel (see this comment).
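
For reference, a minimal hedged sketch of the CUB-based approach described above (the function names, the use of cudaMallocAsync instead of ggml's pool allocator, and the trivial divide kernel are illustrative assumptions, not the PR's actual code):

```cuda
// Hedged sketch only: CUB computes the device-wide sum, then a tiny second kernel
// divides by ncols (CUB has no device-wide mean). Error checking is omitted.
#include <cstdint>
#include <cub/device/device_reduce.cuh>

static __global__ void divide_sum_by_ncols(float * dst, const int64_t ncols) {
    *dst /= (float) ncols;
}

static void mean_row_cub(const float * src, float * dst, const int64_t ncols, cudaStream_t stream) {
    void * d_temp   = nullptr;
    size_t temp_len = 0;
    // First call only queries the required temporary-storage size.
    cub::DeviceReduce::Sum(d_temp, temp_len, src, dst, ncols, stream);
    cudaMallocAsync(&d_temp, temp_len, stream);
    cub::DeviceReduce::Sum(d_temp, temp_len, src, dst, ncols, stream);
    cudaFreeAsync(d_temp, stream);

    divide_sum_by_ncols<<<1, 1, 0, stream>>>(dst, ncols);
}
```

This also makes the launch-overhead point above concrete: the host issues two kernel launches per mean (CUB's reduction plus the divide kernel), versus one for reduce_rows_f32.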

@ORippler
Contributor Author

I personally feel the CUB-based implementation is a bit beyond the original scope of this PR. However, since I am unable to create branches in the base repo and am unaware of how to represent stacked PRs on GitHub for PRs filed across forks, I left it in here.

Collaborator

@JohannesGaessler JohannesGaessler left a comment


Thank you for the high-effort PR.

@ORippler
Contributor Author

@JohannesGaessler Could we get this merged whenever you have the time? Unfortunately I don't have write access 🙈

@JohannesGaessler JohannesGaessler merged commit 6028bf7 into ggml-org:master Aug 13, 2025
47 checks passed
@JohannesGaessler
Collaborator

Ah sorry, I wanted to merge this yesterday (after the CI finishes) and I forgot about it.

@ORippler ORippler deleted the osimons/optimize_reduce_rows_f32 branch August 13, 2025 08:05
ggerganov pushed a commit to ggml-org/ggml that referenced this pull request Aug 13, 2025
…vement on kernel-level and 10% perf increase for Gemma3n (llama/15132)

* Factor out `reduce_rows_f32` from common.cuh

This increases iteration cycle speed by not having to recompile
every kernel all the time

* Hide memory-latency by loop unrolling in reduce_rows_f32

* Further optimizations to `reduce_rows_f32`

1. Increase threadblock size to better hide latency of memory requests.
   As a consequence of bigger threadblocks, do 2-step summation, using
   shared memory to communicate results between invocations
2. Use sum_temp array to reduce waits on sum
3. Adjust num_unroll to reflect bigger threadblock
4. Improve default block_dims, increase support for more block_dims

* Add perf tests for `reduce_rows_f32` kernel

* Add heuristic to toggle 128/512 threads based on sm count

Break even point was the minimum of the following multiples.

| GPU Model                     | Nrow SM Count Multiple |
| -----------                   | -----------            |
| RTX 4000 SFF ADA              | 2.0x                   |
| RTX 6000 ADA                  | 2.5x                   |
| RTX PRO 6000 Blackwell Max-Q  | 3.04x                  |
| RTX PRO 4500 Blackwell        | 3.15x                  |

* Ensure perf gains also for small ncols and large nrows

Alternative to this, one could have also made the number of unrollings
template-able, but that would require compiling the kernel multiple
times, increasing binary size unnecessarily

* Modify perf and unit-tests

* Apply auto-formatting by clang

* Fix CI build failure

See https://github.com/ggml-org/llama.cpp/actions/runs/16798370266/job/47573716079?pr=15132#step:7:486
Building with VS generator worked though.

* Remove sm_count property from `ggml_backend_cuda_context`

Requested by @JohannesGaessler, and should fix remaining CI issues as a
side-effect

* Add CUB-based implementation for GGML_OP_MEAN

Currently this branch is only executed for nrows==1

* Add heuristics to execute CUB branch only when it brings perf

Heuristics were determined on the following HW:

* RTX 4000 SFF ADA
* RTX 6000 ADA
* RTX PRO 6000 Blackwell Max-Q
* RTX PRO 4500 Blackwell

* Add unit-test for CUB-based mean

Tests should run with CUDA Graphs enabled per default on NVGPUs

* Rename `USE_CUB` to `GGML_CUDA_USE_CUB`

Suggested by @JohannesGaessler

* Unindent Preprocessor directives

See
ggml-org/llama.cpp#15132 (comment)

ggerganov pushed a commit to ggml-org/ggml that referenced this pull request Aug 14, 2025
…vement on kernel-level and 10% perf increase for Gemma3n (llama/15132)

ggerganov pushed a commit to ggml-org/whisper.cpp that referenced this pull request Aug 18, 2025
…vement on kernel-level and 10% perf increase for Gemma3n (llama/15132)

ggerganov pushed a commit to ggml-org/whisper.cpp that referenced this pull request Aug 18, 2025
…vement on kernel-level and 10% perf increase for Gemma3n (llama/15132)