
Conversation

@mnehete32 (Contributor)

Added Tensor Core support to the code from #16088, with modifications tuned to give the best results on tensor cores. The results below are from an RTX 2070 GPU.

FP16 Tensor Core perf

```
CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):      55 runs - 18401.09 us/run - 137.42 GFLOP/run -   7.47 TFLOPS
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):        28424 runs -    35.24 us/run - 133.69 MFLOP/run -   3.79 TFLOPS
CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):        19899 runs -    50.62 us/run - 135.78 MFLOP/run -   2.68 TFLOPS
CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):         122880 runs -     8.58 us/run - 642.82 kFLOP/run -  74.95 GFLOPS
CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):         38288 runs -    28.19 us/run -  20.90 MFLOP/run - 741.40 GFLOPS
CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):         57344 runs -    18.43 us/run -   2.78 MFLOP/run - 151.07 GFLOPS
CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):          8978 runs -   134.73 us/run -  22.28 MFLOP/run - 165.35 GFLOPS
CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):        28611 runs -    34.96 us/run - 115.40 MFLOP/run -   3.30 TFLOPS
CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):         4251 runs -   235.69 us/run - 923.24 MFLOP/run -   3.92 TFLOPS
CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):      3465 runs -   293.17 us/run -   1.85 GFLOP/run -   6.31 TFLOPS
```
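
For readers unfamiliar with the tensor core path: a conv2d kernel like this is ultimately lowered to warp-level FP16 matrix-multiply-accumulate tiles. The sketch below is not this PR's kernel (the PR builds on ggml's `mma.cuh` primitives, per the includes in the diff); it is a minimal `nvcuda::wmma` illustration of the underlying 16x16x16 MMA operation, assuming sm_70+ and a launch with at least one full warp:

```cpp
// Minimal illustration of the FP16 tensor-core mechanism, NOT the PR's
// kernel. One warp computes a single 16x16 output tile C = A * B.
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;

__global__ void wmma_f16_tile(const half * A, const half * B, float * C) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>              c_frag;

    wmma::fill_fragment(c_frag, 0.0f);
    wmma::load_matrix_sync(a_frag, A, 16);            // leading dimension 16
    wmma::load_matrix_sync(b_frag, B, 16);
    wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);   // the tensor core op
    wmma::store_matrix_sync(C, c_frag, 16, wmma::mem_row_major);
}
```

In a conv2d-as-GEMM kernel, A would be a tile of the (implicitly im2col'ed) input and B a tile of the kernel weights, with the accumulator written back to the output feature map.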

@etasnadi @Green-Sky @JohannesGaessler

On Oct 28, 2025, the github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels.
@mnehete32 (Contributor, Author)

Keeping this as a draft until the implicit or Vulkan changes are merged. I’ll integrate the tensor core kernel with that code.

@mnehete32 (Contributor, Author)

Hey @Green-Sky, could we also get an sd.cpp perf analysis for this draft?

I’ve exposed the tensor core kernel through conv2d_direct.

@Green-Sky (Collaborator) commented on Nov 1, 2025

Ran a bench on this PR and added it here: #15805 (comment).

Looks like this is now the fastest version!

VAE decoding is also slightly faster than im2col+matmul (maybe, might be within error).


sd1 fp16 512x768

| method | sampling time | sampling memory | decoding time | decoding memory |
| --- | --- | --- | --- | --- |
| CUDA imcol+mul | 0.21s | 189.38 MB | 0.75s | 2496.09 MB |
| CUDA direct (master) | 2.96s | 132.71 MB | 16.79s | 1056.09 MB |
| CUDA direct (bssrdf pr c1f67c1) | 0.37s | 132.71 MB | 1.00s | 1056.09 MB |
| CUDA direct (mnehete32_tensor pr e3f94c6) | 0.30s | 132.71 MB | 0.74s | 1056.09 MB |

sd1 fp16 768x1024 (like the old table)

| method | sampling time | sampling memory | decoding time | decoding memory |
| --- | --- | --- | --- | --- |
| CUDA imcol+mul | 0.58s | 373.64 MB | 1.55s | 4992.19 MB |
| CUDA direct (master) | 6.29s | 260.30 MB | 34.94s | 2112.19 MB |
| CUDA direct (bssrdf pr c1f67c1) | 0.85s | 260.30 MB | 2.03s | 2112.19 MB |
| CUDA direct (mnehete32_tensor pr e3f94c6) | 0.73s | 260.30 MB | 1.52s | 2112.19 MB |

sdxl fp16/q8_0 1024x1280

Diffusion model is q8_0 and VAE is fp16.

| method | sampling time | sampling memory | decoding time | decoding memory |
| --- | --- | --- | --- | --- |
| CUDA imcol+mul | 0.79s | 614.83 MB | OOM | 9600.31 MiB (alloc error) |
| CUDA direct (master) | 11.05s | 288.43 MB | 60.57s | 4800.31 MB |
| CUDA direct (bssrdf pr c1f67c1) | 1.15s | 288.43 MB | 3.72s | 4800.31 MB |
| CUDA direct (mnehete32_tensor pr e3f94c6) | 1.00s | 288.43 MB | 2.92s | 4800.31 MB |

```cpp
__constant__ __device__ Params P;

// see init_fastdiv_values in ggml-vulkan.cpp
__inline__ __device__ uint fastdiv(uint n, uint mp, uint L) {
```
Collaborator review comment:
Already exists in common.

```cpp
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
```
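
For reference, the scheme behind `fastdiv` (see `init_fastdiv_values` in ggml-vulkan.cpp) turns division by a runtime-constant divisor d into one high multiply and a shift. Below is a rough sketch of the idea; the exact signatures and constant packing in ggml's common header may differ:

```cpp
// Sketch of division by a precomputed "magic" constant (Granlund/Montgomery
// style). Not the exact ggml code.
#include <cstdint>

// Host side: for divisor d >= 1, choose L = ceil(log2(d)) and the magic
// multiplier mp = floor(2^32 * (2^L - d) / d) + 1.
static void init_fastdiv_values(uint32_t d, uint32_t & mp, uint32_t & L) {
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    mp = (uint32_t)(((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
}

// Device side: n / d == (umulhi(mp, n) + n) >> L for the mp/L above
// (assumes n is small enough that the 32-bit add cannot overflow, which
// holds for the tensor-dimension ranges this is used for).
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
    return (__umulhi(mp, n) + n) >> L;
}
```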


```cpp
#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

static uint32_t ceil_div(uint32_t M, uint32_t N);
```
Collaborator review comment:
```cpp
constexpr size_t ceil_div(const size_t m, const size_t n) {
```

#include "convert.cuh"
#include "mma.cuh"

#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))
Collaborator review comment:
Remove the macro and use a function instead.
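
For concreteness, a minimal sketch of the suggested replacement, mirroring the `constexpr` helper referenced above:

```cpp
#include <cstddef>

// Typed, constexpr replacement for the CEIL_DIV macro: rounds the
// quotient up, e.g. ceil_div(10, 4) == 3, matching CEIL_DIV(10, 4).
constexpr size_t ceil_div(const size_t m, const size_t n) {
    return (m + n - 1) / n;
}
```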
