
Conversation

@bugparty (Contributor) commented Sep 13, 2025

In the previous PR #14659, JohannesGaessler said (#14659 (comment)):

This is going to produce correct results but generally speaking you will get much better performance if each thread just works on a single value instead of looping over ne0. However, it would also be fine to just merge it as-is and maybe change this later if it ever becomes relevant for end-to-end performance.

So I wrote a version without the loop. It benefits smaller tensor sizes the most, but it also improves performance across every size tested.
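Conceptually, the change maps one thread to one output element instead of looping over ne0 inside each thread. Below is a minimal sketch of the idea (simplified indexing, contiguous f32 tensors assumed; this is not the exact kernel in the PR):

```cuda
// Sketch: one thread per output element along ne0, one row per blockIdx.y.
// Reflection indexing follows PAD_REFLECT_1D semantics (edge value not repeated).
static __global__ void pad_reflect_1d_sketch_f32(
        const float * src, float * dst,
        const int ne0_src, const int ne0_dst, const int p0) {
    const int i0  = blockIdx.x*blockDim.x + threadIdx.x; // output index along dim 0
    const int row = blockIdx.y;                          // flattened i1*i2*i3 index
    if (i0 >= ne0_dst) {
        return;
    }

    int j = i0 - p0;                                 // corresponding source index
    j = j < 0        ? -j                  : j;      // reflect at the left edge
    j = j >= ne0_src ? 2*(ne0_src - 1) - j : j;      // reflect at the right edge

    dst[(int64_t) row*ne0_dst + i0] = src[(int64_t) row*ne0_src + j];
}
```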

By the way, I also added more test cases for PAD_REFLECT_1D.

here is the benchmark summary:

Benchmark Results

| Tensor Shape (ne_a) | Before (GB/s) | After (GB/s) | Improvement |
|---|---|---|---|
| [512, 34, 2, 1] | 70.85 | 91.82 | +29.6% |
| [3000, 80, 1, 1] | 450.38 | 570.28 | +26.6% |
| [3000, 80, 4, 1] | 447.39 | 571.26 | +27.7% |
| [3000, 384, 1, 1] | 339.07 | 362.65 | +6.9% |
| [3000, 384, 4, 1] | 339.30 | 362.59 | +6.9% |

Highlights

  • Overall ~7–30% bandwidth improvement, depending on tensor shape.
  • Best gains on the small and medium shapes ([512,34,2,1], [3000,80,1,1], [3000,80,4,1]): roughly +27–30%.
  • Large tensors ([3000,384,1,1], [3000,384,4,1]) also show a consistent ~+7% improvement.
  • Average bandwidth gain across all tested shapes: roughly +19%.

raw benchmark data:

old kernel

.\test-backend-ops.exe perf -o PAD_REFLECT_1D -b CUDA0
Backend 1/2: CUDA0
  Device description: NVIDIA GeForce RTX 3080 Laptop GPU
  Device memory: 16383 MB (15253 MB free)

  PAD_REFLECT_1D(type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9):                540672 runs -     1.86 us/run -      138 kB/run -   70.85 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,80,1,1],pad_0=10,pad_1=9):               253952 runs -     3.98 us/run -     1880 kB/run -  450.38 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,384,1,1],pad_0=10,pad_1=9):               40887 runs -    25.39 us/run -     9028 kB/run -  339.07 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,80,4,1],pad_0=10,pad_1=9):               253952 runs -     4.01 us/run -     1880 kB/run -  447.39 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9):               40887 runs -    25.38 us/run -     9028 kB/run -  339.30 GB/s

new kernel

.\test-backend-ops.exe perf -o PAD_REFLECT_1D 
Backend 1/2: CUDA0
  Device description: NVIDIA GeForce RTX 3080 Laptop GPU
  Device memory: 16383 MB (15253 MB free)

  PAD_REFLECT_1D(type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9):                696320 runs -     1.44 us/run -      138 kB/run -   91.82 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,80,1,1],pad_0=10,pad_1=9):               319488 runs -     3.15 us/run -     1880 kB/run -  570.28 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,80,4,1],pad_0=10,pad_1=9):               319488 runs -     3.14 us/run -     1880 kB/run -  571.26 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,384,1,1],pad_0=10,pad_1=9):               44604 runs -    23.74 us/run -     9028 kB/run -  362.65 GB/s
  PAD_REFLECT_1D(type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9):               44604 runs -    23.75 us/run -     9028 kB/run -  362.59 GB/s

feat: add more test cases for PAD_REFLECT_1D
@github-actions bot added the testing, Nvidia GPU, and ggml labels on Sep 13, 2025
@yosh20004 commented

Thanks for the PR — really nice work!

I was exploring a similar implementation, and we both use a threadblock layout tied to ne0, but I think keeping a loop structure in this kernel may help hide memory access latency, so I want to introduce a parameter UNROLL to control a grid-stride loop.

In short, I want ne0 / UNROLL threads per row, with each thread processing UNROLL elements in a single iteration; I think your approach effectively implies UNROLL=1.
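For illustration, here is a minimal sketch of that idea (hypothetical names and simplified indexing; the actual draft is linked below):

```cuda
// Sketch of the UNROLL idea: each thread writes UNROLL consecutive output
// elements along ne0, so a row needs ne0_dst / UNROLL threads instead of ne0_dst.
template <int UNROLL>
static __global__ void pad_reflect_1d_unroll_sketch_f32(
        const float * src, float * dst,
        const int ne0_src, const int ne0_dst, const int p0) {
    const int row  = blockIdx.y;
    const int base = (blockIdx.x*blockDim.x + threadIdx.x) * UNROLL;

    const float * src_row = src + (int64_t) row*ne0_src;
    float       * dst_row = dst + (int64_t) row*ne0_dst;

#pragma unroll
    for (int u = 0; u < UNROLL; ++u) {
        const int i0 = base + u;
        if (i0 >= ne0_dst) {
            return;
        }
        int j = i0 - p0;                             // reflect back into [0, ne0_src)
        j = j < 0        ? -j                  : j;
        j = j >= ne0_src ? 2*(ne0_src - 1) - j : j;
        dst_row[i0] = src_row[j];
    }
}
```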

Here are some benchmark results (Device: 3060M 6G):

| Config | Shape (ne_a) | Runs | Time (us/run) | Size (kB/run) | Bandwidth (GB/s) | Speedup vs PR |
|---|---|---|---|---|---|---|
| UNROLL=4 | [512,34,2,1] | 491520 | 2.04 | 138 | 64.83 | 0.84x |
| UNROLL=4 | [3000,80,1,1] | 229376 | 4.45 | 1880 | 403.39 | 1.28x |
| UNROLL=4 | [3000,384,1,1] | 29736 | 35.02 | 9028 | 245.88 | 1.19x |
| UNROLL=4 | [3000,80,4,1] | 229376 | 4.47 | 1880 | 401.21 | 1.37x |
| UNROLL=4 | [3000,384,4,1] | 29736 | 35.07 | 9028 | 245.50 | 1.19x |
| UNROLL=1 | [512,34,2,1] | 565248 | 1.78 | 138 | 74.18 | 0.96x |
| UNROLL=1 | [3000,80,1,1] | 155648 | 6.67 | 1880 | 269.07 | 0.85x |
| UNROLL=1 | [3000,384,1,1] | 22302 | 46.04 | 9028 | 187.02 | 0.91x |
| UNROLL=1 | [3000,80,4,1] | 147456 | 6.85 | 1880 | 261.91 | 0.89x |
| UNROLL=1 | [3000,384,4,1] | 26019 | 42.07 | 9028 | 204.67 | 0.99x |
| Your PR | [512,34,2,1] | 589824 | 1.71 | 138 | 77.45 | 1.00x |
| Your PR | [3000,80,1,1] | 180224 | 5.69 | 1880 | 315.06 | 1.00x |
| Your PR | [3000,384,1,1] | 26019 | 41.79 | 9028 | 206.06 | 1.00x |
| Your PR | [3000,80,4,1] | 172032 | 6.12 | 1880 | 293.23 | 1.00x |
| Your PR | [3000,384,4,1] | 26019 | 41.55 | 9028 | 207.21 | 1.00x |

Not a big change, but it might improve robustness across workloads.

If you think it makes sense, you can take a look at my draft implementation:
https://github.com/yosh20004/llama.cpp/blob/feat/optimize-pad_reflect_1d/ggml/src/ggml-cuda/pad_reflect_1d.cu

@bugparty (Contributor, Author) commented

Hmm, your code is really good: if the tensor shape is small, the no-loop version performs better, and if the shape is large enough, your unrolled version is the best.

@JohannesGaessler (Collaborator) left a comment

On some GPUs it's faster to pass float pointers and to calculate offsets like s01 = nb01 / sizeof(float) in host code.
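A minimal sketch of that pattern (hypothetical variable names; the kernel parameter list and launch dimensions are abridged):

```cuda
// Host side (e.g. in ggml_cuda_op_pad_reflect_1d): convert byte strides to
// element strides once, then pass float pointers plus element strides.
const ggml_tensor * src0 = dst->src[0];

const int64_t s01 = src0->nb[1] / sizeof(float); // src row stride, in elements
const int64_t s1  = dst->nb[1]  / sizeof(float); // dst row stride, in elements

pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
    (const float *) src0->data, (float *) dst->data, s01, s1 /* , ... */);
```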

@yosh20004 commented

On some GPUs it's faster to pass float pointers and to calculate offsets like s01 = nb01 / sizeof(float) in host code.

You are right. If the tensor is small, I guess we need to activate more CUDA cores to increase SM occupancy. Maybe we could have a strategy that sets a different UNROLL value according to the tensor size, but it would also make the code more complex.

Such as:

template<int UNROLL=1>
static __global__ void pad_reflect_1d_kernel_f32(...);

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    // same as your code
    if (TENSORSIZE <= LIMIT) {
        pad_reflect_1d_kernel_f32<1><<<grid, block, 0, stream>>>(...); // better for small tensors
    } else {
        pad_reflect_1d_kernel_f32<4><<<grid, block, 0, stream>>>(...); // better for large tensors
    }
}

I think the right value of LIMIT depends on the GPU architecture; on my device it is about 32768 total elements (i0 * i1 * i2 * i3), but I am not sure.

@JohannesGaessler (Collaborator) left a comment

I forgot: try adding __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) to the kernel. This tells the compiler that the kernel will only ever be launched with 256 threads so the compiler can optimize it more aggressively.
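For reference, a minimal sketch of how that attribute is applied (assuming the existing block-size constant is 256, as stated above; the parameter list is abridged):

```cuda
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256

// __launch_bounds__(max_threads_per_block, min_blocks_per_multiprocessor):
// promising the compiler an exact upper bound on the block size lets it
// budget registers more aggressively for this kernel.
static __global__ void __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1)
    pad_reflect_1d_kernel_f32(/* ... */) {
    // kernel body unchanged
}
```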

@bugparty (Contributor, Author) commented

I forgot: try adding __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) to the kernel. This tells the compiler that the kernel will only ever be launched with 256 threads so the compiler can optimize it more aggressively.

Hi, I did add it, but the performance stays the same.
By the way, are you open to launching different kernels depending on the tensor size? If so, @yosh20004 and I can open a new PR to speed up PAD_REFLECT_1D for large tensors.

@JohannesGaessler (Collaborator) commented

We don't need the absolute fastest implementation of padding, most likely it takes up almost nothing of the actual runtime. I'm willing to maintain a single variant of the kernel unless it can be demonstrated that there is a use case where having two kernels makes a meaningful difference for the end-to-end performance.

@bugparty (Contributor, Author) commented

We don't need the absolute fastest implementation of padding, most likely it takes up almost nothing of the actual runtime. I'm willing to maintain a single variant of the kernel unless it can be demonstrated that there is a use case where having two kernels makes a meaningful difference for the end-to-end performance.

Got it, thank you for the guidance. The final speed improvement for shape [512, 34, 2, 1] is +29.6%. Thanks again for the help and hints.

@JohannesGaessler JohannesGaessler merged commit 38dbdf4 into ggml-org:master Sep 18, 2025
53 of 55 checks passed
yael-works pushed a commit to yael-works/llama.cpp that referenced this pull request Oct 15, 2025
* CUDA: Optimize PAD_REFLECT_1D
feat: add more test cases for PAD_REFLECT_1D

* use fast_div to improve performance

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <[email protected]>

* Apply suggestion from JohannesGaessler

Co-authored-by: Johannes Gäßler <[email protected]>

* optimize

* use a concise expression to further speedup the cuda kernel

---------

Co-authored-by: Johannes Gäßler <[email protected]>
pwilkin pushed a commit to pwilkin/llama.cpp that referenced this pull request Oct 23, 2025