
Conversation

mnehete32
Contributor

Part of #14909

@github-actions github-actions bot added Nvidia GPU Issues specific to Nvidia GPUs ggml changes relating to the ggml tensor library for machine learning labels Aug 28, 2025
@JohannesGaessler JohannesGaessler merged commit c97dc09 into ggml-org:master Aug 28, 2025
48 checks passed
ExtReMLapin pushed a commit to ExtReMLapin/llama.cpp that referenced this pull request Aug 28, 2025
* CUDA: add conv2d

* CUDA: conv2d - correct formatting and added const
Minh141120 pushed a commit to menloresearch/llama.cpp that referenced this pull request Aug 29, 2025
* CUDA: add conv2d

* CUDA: conv2d - correct formatting and added const
@Green-Sky
Collaborator

Green-Sky commented Aug 29, 2025

Just ran a test using sd.cpp, and for VAE, this is ~25 times slower than the im2col+mat_mul version.

some numbers from the same device:
cuda im2col+mat_mul -> ~1s
cuda new conv2d -> ~25s
vulkan im2col+mat_mul (cm2) -> ~2.6s
vulkan conv2d (cm2) -> ~0.70s

So there is still a lot of room for improvement. (:

edit: also this code works, as far as I can tell :)

@mnehete32
Contributor Author

mnehete32 commented Aug 29, 2025

Yeah, it's definitely slower. I wasn't sure whether I could actually use a memory buffer for this, or how large it could safely be, so I just went with this approach. If it's fine to use something like

#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)

(the same as used in the CPU conv2d pull request), then I can switch over and do the im2col + mat_mul version.

I'd reuse the existing im2col kernel and call cuBLAS on that patch for the mat_mul. @Green-Sky
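For concreteness, the tiled approach under discussion can be sketched in plain C++ (the actual CUDA path would run the existing im2col kernel into a fixed-size device buffer and call cuBLAS per tile; `max_work_elems` stands in for GGML_IM2COL_WORK_SIZE, and all names here are illustrative, not code from this PR):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Direct conv2d (valid padding, stride 1), used as the reference.
// input: [C][H][W], kernel: [K][C][KH][KW], output: [K][OH][OW], all flattened.
std::vector<float> conv2d_direct(const std::vector<float> & in, const std::vector<float> & ker,
                                 int C, int H, int W, int K, int KH, int KW) {
    const int OH = H - KH + 1, OW = W - KW + 1;
    std::vector<float> out((size_t) K*OH*OW, 0.0f);
    for (int k = 0; k < K; ++k)
        for (int oy = 0; oy < OH; ++oy)
            for (int ox = 0; ox < OW; ++ox) {
                float acc = 0.0f;
                for (int c = 0; c < C; ++c)
                    for (int ky = 0; ky < KH; ++ky)
                        for (int kx = 0; kx < KW; ++kx)
                            acc += in[(c*H + oy + ky)*W + ox + kx]
                                 * ker[((k*C + c)*KH + ky)*KW + kx];
                out[(k*OH + oy)*OW + ox] = acc;
            }
    return out;
}

// Tiled im2col + GEMM: materialize at most max_work_elems elements of the
// im2col matrix at a time (standing in for GGML_IM2COL_WORK_SIZE), run a
// plain GEMM on each tile, and write that slice of the output.
std::vector<float> conv2d_im2col_tiled(const std::vector<float> & in, const std::vector<float> & ker,
                                       int C, int H, int W, int K, int KH, int KW,
                                       size_t max_work_elems) {
    const int OH = H - KH + 1, OW = W - KW + 1;
    const int patch = C*KH*KW;   // length of one im2col row
    const int total = OH*OW;     // number of output positions
    const int tile  = (int) std::max<size_t>(1, max_work_elems / (size_t) patch);
    std::vector<float> out((size_t) K*total, 0.0f);
    std::vector<float> work;     // the bounded im2col buffer
    for (int p0 = 0; p0 < total; p0 += tile) {
        const int n = std::min(tile, total - p0);
        work.assign((size_t) n*patch, 0.0f);
        for (int i = 0; i < n; ++i) {   // im2col for this tile of output positions
            const int oy = (p0 + i) / OW, ox = (p0 + i) % OW;
            for (int c = 0; c < C; ++c)
                for (int ky = 0; ky < KH; ++ky)
                    for (int kx = 0; kx < KW; ++kx)
                        work[(size_t) i*patch + (c*KH + ky)*KW + kx] =
                            in[(c*H + oy + ky)*W + ox + kx];
        }
        // GEMM for this tile: kernel [K x patch] times work^T [patch x n].
        for (int k = 0; k < K; ++k)
            for (int i = 0; i < n; ++i) {
                float acc = 0.0f;
                for (int j = 0; j < patch; ++j)
                    acc += ker[(size_t) k*patch + j] * work[(size_t) i*patch + j];
                out[(size_t) k*total + p0 + i] = acc;
            }
    }
    return out;
}
```

The work buffer bounds peak memory regardless of input size, which is the point of the fixed GGML_IM2COL_WORK_SIZE constant.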

@Green-Sky
Collaborator

Yeah, it's definitely slower. I wasn't sure whether I could actually use a memory buffer for this, or how large it could safely be, so I just went with this approach. If it's fine to use something like

#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024)

(the same as used in the CPU conv2d pull request), then I can switch over and do the im2col + mat_mul version.

I'd reuse the existing im2col kernel and call cuBLAS on that patch for the mat_mul. @Green-Sky

Yeah, doing "a tiled im2col + GEMM approach" similar to the CPU implementation should work.

@mnehete32
Contributor Author

working on it

@JohannesGaessler
Collaborator

You can formulate a convolution as a matrix multiplication more generally. For optimal performance (for large input tensors), what would need to be done is load the data into shared memory, then load it into registers and use tensor cores. IIRC you need a minimum number of channels to fully utilize tensor cores, so I think it will also be necessary to write variants with different memory organization patterns.
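A minimal illustration of this "convolution as matrix multiplication" formulation (often called implicit GEMM), in plain C++ for clarity; a real CUDA kernel built on this idea would stage tiles of the virtual im2col matrix in shared memory, move them through registers, and feed them to tensor cores. The function name and layout are assumptions, not code from this PR:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Implicit-GEMM view of conv2d (valid padding, stride 1): the output is a
// [K x (OH*OW)] matrix product in which the right-hand im2col matrix is never
// materialized -- element (j, i) is recomputed from index math on the fly.
std::vector<float> conv2d_implicit_gemm(const std::vector<float> & in, const std::vector<float> & ker,
                                        int C, int H, int W, int K, int KH, int KW) {
    const int OH = H - KH + 1, OW = W - KW + 1;
    const int patch = C*KH*KW;
    std::vector<float> out((size_t) K*OH*OW, 0.0f);
    for (int k = 0; k < K; ++k)
        for (int i = 0; i < OH*OW; ++i) {      // one output position = one GEMM column
            float acc = 0.0f;
            for (int j = 0; j < patch; ++j) {  // one step of the GEMM inner product
                const int c  = j / (KH*KW);
                const int ky = (j / KW) % KH;
                const int kx = j % KW;
                const int oy = i / OW, ox = i % OW;
                // virtual im2col element (j, i), computed instead of stored
                acc += ker[(size_t) k*patch + j] * in[(c*H + oy + ky)*W + ox + kx];
            }
            out[(size_t) k*OH*OW + i] = acc;
        }
    return out;
}
```

The inner-product length is C*KH*KW, which is why a small channel count leaves the tensor-core tiles underfilled unless the layout is rearranged.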

@rmatif
Collaborator

rmatif commented Aug 29, 2025

You can formulate a convolution as a matrix multiplication more generally. For optimal performance (for large input tensors), what would need to be done is load the data into shared memory, then load it into registers and use tensor cores. IIRC you need a minimum number of channels to fully utilize tensor cores, so I think it will also be necessary to write variants with different memory organization patterns.

Can't we just use padding to align the channel dimension with the tensor core requirements?

@JohannesGaessler
Collaborator

You can, but for e.g. an RGB image with 3 channels you would be wasting at least 5/8 of the compute.
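The 5/8 figure follows from padding 3 channels up to 8: only 3 of the 8 lanes carry real data. A toy helper (illustrative, not from the PR) makes the arithmetic explicit:

```cpp
#include <cassert>
#include <cmath>

// Fraction of multiply-accumulates wasted when the channel count is
// zero-padded up to the next multiple of `align` (8 in the cuDNN case):
// the padded lanes multiply zeros, so only c/padded of the work is useful.
double padding_waste(int c, int align) {
    const int padded = ((c + align - 1) / align) * align;
    return 1.0 - (double) c / padded;
}
```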

@rmatif
Collaborator

rmatif commented Aug 29, 2025

According to the cuDNN documentation, the number of channels must be a multiple of 8 to use tensor cores, so they apply padding under the hood. I think there may be some scenarios where it's still worth taking the tensor core path even at half speed, rather than using regular cores.

@JohannesGaessler
Collaborator

Using tensor cores with padding will be faster than not using them, but using tensor cores with a memory access pattern that gets higher utilization for less than 8 channels will likely be even better.

@etasnadi
Contributor

Hi,

Thanks for your contribution.

Have you checked the Vulkan CONV_2D implementation? I compared the performance of this code and it is 8 to 10 times slower than the Vulkan impl on my RTX 2060 device, so it heavily underutilizes the GPU for some reason.

Is there any advantage of this impl over the Vulkan kernel? If not, it might be better to translate the Vulkan kernel or fix the issues with this kernel. I've already done this, but did not commit because I could not reach the perf of cuBLAS yet...

Vulkan:

  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     41 runs - 24893.68 us/run - 137.42 GFLOP/run -   5.52 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20196 runs -    51.37 us/run - 133.69 MFLOP/run -   2.60 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14003 runs -    72.59 us/run - 135.78 MFLOP/run -   1.87 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                196608 runs -     5.18 us/run - 642.82 kFLOP/run - 123.98 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                47860 runs -    22.75 us/run -  20.90 MFLOP/run - 918.35 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                73728 runs -    13.97 us/run -   2.78 MFLOP/run - 199.36 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    92.25 us/run -  22.28 MFLOP/run - 241.49 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               22542 runs -    45.59 us/run - 115.40 MFLOP/run -   2.53 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   320.96 us/run - 923.24 MFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1595 runs -   642.80 us/run -   1.85 GFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     42 runs - 23886.05 us/run - 137.42 GFLOP/run -   5.75 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               20196 runs -    51.05 us/run - 133.69 MFLOP/run -   2.62 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               14003 runs -    72.37 us/run - 135.78 MFLOP/run -   1.88 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                196608 runs -     5.15 us/run - 642.82 kFLOP/run - 124.77 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                47860 runs -    22.66 us/run -  20.90 MFLOP/run - 922.24 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                73728 runs -    13.95 us/run -   2.78 MFLOP/run - 199.68 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    92.79 us/run -  22.28 MFLOP/run - 240.10 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               22542 runs -    45.19 us/run - 115.40 MFLOP/run -   2.55 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   320.30 us/run - 923.24 MFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1595 runs -   640.66 us/run -   1.85 GFLOP/run -   2.89 TFLOPS

CUDA (0320ac5):

ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 2060 SUPER, compute capability 7.5, VMM: yes
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 2060 SUPER (NVIDIA) | uma: 0 | fp16: 1 | bf16: 0 | warp size: 32 | shared memory: 49152 | int dot: 0 | matrix cores: KHR_coopmat
Testing 3 devices

Backend 1/3: CUDA0
  Device description: NVIDIA GeForce RTX 2060 SUPER
  Device memory: 7787 MB (7693 MB free)
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      4 runs - 327012.75 us/run - 137.42 GFLOP/run - 420.23 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2992 runs -   359.10 us/run - 133.69 MFLOP/run - 372.30 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2948 runs -   364.18 us/run - 135.78 MFLOP/run - 372.84 GFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                139264 runs -     7.49 us/run - 642.82 kFLOP/run -  85.77 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                14358 runs -    92.05 us/run -  20.90 MFLOP/run - 227.02 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                24576 runs -    49.56 us/run -   2.78 MFLOP/run -  56.19 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   386.56 us/run -  22.28 MFLOP/run -  57.63 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3468 runs -   328.07 us/run - 115.40 MFLOP/run - 351.77 GFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 436 runs -  2526.51 us/run - 923.24 MFLOP/run - 365.42 GFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      220 runs -  4768.34 us/run -   1.85 GFLOP/run - 387.74 GFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      3 runs - 345780.33 us/run - 137.42 GFLOP/run - 397.43 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2992 runs -   378.79 us/run - 133.69 MFLOP/run - 352.95 GFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                2948 runs -   384.51 us/run - 135.78 MFLOP/run - 353.13 GFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                131072 runs -     7.66 us/run - 642.82 kFLOP/run -  83.95 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                14358 runs -    94.53 us/run -  20.90 MFLOP/run - 221.05 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                24576 runs -    50.24 us/run -   2.78 MFLOP/run -  55.43 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   390.08 us/run -  22.28 MFLOP/run -  57.11 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3468 runs -   337.92 us/run - 115.40 MFLOP/run - 341.51 GFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 436 runs -  2603.31 us/run - 923.24 MFLOP/run - 354.64 GFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      220 runs -  4947.06 us/run -   1.85 GFLOP/run - 373.73 GFLOPS
  Backend CUDA0: OK
Backend 2/3: Vulkan0
  Skipping
Backend 3/3: CPU
  Skipping
3/3 backends passed
OK

@etasnadi
Contributor

According to the cuDNN documentation, the number of channels must be a multiple of 8 to use tensor cores, so they apply padding under the hood. I think there may be some scenarios where it's still worth taking the tensor core path even at half speed, rather than using regular cores.

No need to explicitly pad because the block size is aligned.
