Conversation

@IMbackK (Collaborator) commented Jan 30, 2025

This adds selectable warp size support to mmv to improve performance on devices with a warp size != 32.

Predictably, this improves performance on CDNA (and GCN). A minimal sketch of the general approach appears after the benchmarks below.

Master:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B F16                   |  14.96 GiB |     8.03 B | ROCm       |  99 |          tg64 |         36.85 ± 0.10 |

PR:

  Device 0: AMD Instinct MI100, gfx908:sramecc+:xnack- (0x908), VMM: no, Wave Size: 64
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B F16                   |  14.96 GiB |     8.03 B | ROCm       |  99 |          tg64 |         49.38 ± 0.10 |

And does nothing for RDNA2:

Master:

  Device 0: AMD Radeon RX 6800 XT, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B F16                   |  14.96 GiB |     8.03 B | ROCm       |  99 |          tg64 |         26.77 ± 0.07 |

PR:

  Device 0: AMD Radeon RX 6800 XT, gfx1030 (0x1030), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl |          test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 8B F16                   |  14.96 GiB |     8.03 B | ROCm       |  99 |          tg64 |         26.81 ± 0.05 |
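
For illustration, here is a minimal sketch of the general approach (illustrative names and launch helper only, not the actual mmv kernel code): the wave width becomes a compile-time template parameter, and the host side picks the instantiation that matches the device's reported warp size.

```cpp
// Illustrative sketch only -- not the PR's actual mmv code.
// Idea: template the kernel on the wave width so the same source compiles to a
// 32-wide version (NVIDIA, RDNA) and a 64-wide version (GCN, CDNA).

template <int warp_size>
__global__ void mul_mat_vec_sketch(const float * x, const float * y, float * dst, const int ncols) {
    __shared__ float partial[warp_size];
    const int row  = blockIdx.x;   // one block (one wave) per output row
    const int lane = threadIdx.x;

    // Strided partial dot product over the row.
    float sum = 0.0f;
    for (int col = lane; col < ncols; col += warp_size) {
        sum += x[(size_t) row*ncols + col] * y[col];
    }
    partial[lane] = sum;
    __syncthreads();

    // Tree reduction across the wave, sized to the actual wave width.
    for (int offset = warp_size/2; offset > 0; offset >>= 1) {
        if (lane < offset) {
            partial[lane] += partial[lane + offset];
        }
        __syncthreads();
    }

    if (lane == 0) {
        dst[row] = partial[0];
    }
}

// Host side: pick the instantiation that matches the device's warp size
// (cudaDeviceProp::warpSize / hipDeviceProp_t::warpSize).
static void launch_mul_mat_vec_sketch(const float * x, const float * y, float * dst,
                                      int nrows, int ncols, int device_warp_size, cudaStream_t stream) {
    if (device_warp_size == 64) {
        mul_mat_vec_sketch<64><<<nrows, 64, 0, stream>>>(x, y, dst, ncols);
    } else {
        mul_mat_vec_sketch<32><<<nrows, 32, 0, stream>>>(x, y, dst, ncols);
    }
}
```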

@IMbackK IMbackK force-pushed the addWarpSize branch 2 times, most recently from a151674 to 9a6a6ef on January 30, 2025 16:47
@IMbackK (Collaborator, Author) commented Jan 30, 2025

I don't like the addition of GGML_TRUE_WARP_SIZE much, but I can't see another way that doesn't:

  1. require moving every kernel to selectable warp size at the same time, or
  2. lose developer intent by just hardcoding 32.
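
Roughly, the two-macro idea would look like this (illustrative sketch only; __AMDGCN_WAVEFRONT_SIZE__ is assumed to be the compiler-provided wave width and GGML_USE_HIP the existing backend guard):

```cpp
// Rough sketch of the two-macro idea (illustrative, not the PR's exact code).
#if defined(GGML_USE_HIP) && defined(__AMDGCN_WAVEFRONT_SIZE__)
#define GGML_TRUE_WARP_SIZE __AMDGCN_WAVEFRONT_SIZE__ // 64 on GCN/CDNA, 32 on RDNA
#else
#define GGML_TRUE_WARP_SIZE 32
#endif

// Kernels that have not yet been converted keep assuming a 32-wide warp:
#define WARP_SIZE 32
```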

@IMbackK (Collaborator, Author) commented Jan 30, 2025

I also don't know if #define GGML_TRUE_WARP_SIZE 32 is correct for MUSA.

The github-actions bot added the labels Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) on Jan 30, 2025
@IMbackK IMbackK force-pushed the addWarpSize branch 2 times, most recently from 09b02bc to f5dd31f on January 30, 2025 17:40
@Beinsezii (Contributor)

Does gfx11 not also support w64, or does that not matter here?

@IMbackK (Collaborator, Author) commented Jan 31, 2025

RDNA can be run in wave64 mode, and on RDNA3 this can provide huge performance improvements, as RDNA3 can dual-issue halves of a 64-wide wave for some operations, doubling throughput in these instances.

However, ROCm does not support RDNA in wave64 mode on HIP:

https://github.com/ROCm/HIP/blob/c1f7109cdd0e7921403cea649baf24a3c38cdd20/include/hip/hip_runtime.h#L40

The reason for this is that the RDNA ISA lacks some 64-wide across-wave operations in wave64 mode that HIP requires.

Regardless, if AMD somehow lifted this limitation and you compiled llama.cpp with '-mwavefrontsize64', this PR would detect that we are now in wave64 mode and work fine.

In reality, you will probably never see more than half of peak throughput on RDNA3 in regular generic HIP code. Either you have to use V_PK 2x32-bit instructions by hand in wave32 mode, or WMMA, which also internally dual-issues to the ALUs, where applicable.

@Beinsezii (Contributor)

Damn. They seem to have gfx11 just on the back burner for gfx94. Maybe we could open an issue on HIP just to see if it gets any attention at least.

@BlueSwordM commented Feb 1, 2025

Does this also work for GFX906 GPUs, like the Radeon VII/MI50/MI60?
I don't seem to be getting large speedups on my end:
cmake -B build -DGGML_HIP=ON -DAMDGPU_TARGETS=gfx906 -DBUILD_SHARED_LIBS=OFF

Device 0: AMD Radeon VII, gfx906:sramecc+:xnack- (0x906), VMM: no, Wave Size: 64

Is this only applicable to FP16 models?
I'm on ROCm 6.2.4 for reference, so far above the 5.5 requirement.
CachyOS (Arch) with 6.13.0 kernel.

@IMbackK (Collaborator, Author) commented Feb 1, 2025

> Is this only applicable to FP16 models? I'm on ROCm 6.2.4 for reference, so far above the 5.5 requirement. CachyOS (Arch) with 6.13.0 kernel.

This only affects mmv; quantized models mostly use mmvq, so you should not expect any change with quantized models.

> Damn. They seem to have gfx11 just on the back burner for gfx94. Maybe we could open an issue on HIP just to see if it gets any attention at least.

No, the ISA just doesn't support the required operations in wave64 mode; this is not something AMD can solve.

@BlueSwordM

> Is this only applicable to FP16 models? I'm on ROCm 6.2.4 for reference, so far above the 5.5 requirement. CachyOS (Arch) with 6.13.0 kernel.

> This only affects mmv; quantized models mostly use mmvq, so you should not expect any change with quantized models.

That's what I guessed, thank you.
Would it theoretically be possible to perform such an operation with mmvq, considering llama.cpp internally converts quantized int weights to FP16/FP32 at runtime? Is that even possible?

@IMbackK (Collaborator, Author) commented Feb 1, 2025

Sure, it's possible; it's also the plan.

@JohannesGaessler (Collaborator) left a comment

My preference would be to somehow define constexpr int warp_size = 64 at the beginning of the kernel and then use that instead of the WARP_SIZE macro. How about this: define a function like constexpr __device__ ggml_cuda_get_physical_warp_size in common.cuh and make that function return 32 by default but 64 for specific AMD architectures and compile flags.
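
A sketch of what such a helper could look like (the guards shown are assumptions, not necessarily the final code; __AMDGCN_WAVEFRONT_SIZE__ is the wave width reported by the HIP/clang device compiler and also reflects -mwavefrontsize64):

```cpp
// Sketch of the suggested helper (guards are assumptions, not necessarily the merged code).
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(__HIP_DEVICE_COMPILE__) && defined(__AMDGCN_WAVEFRONT_SIZE__)
    return __AMDGCN_WAVEFRONT_SIZE__; // 64 on GCN/CDNA, 32 on RDNA (64 with -mwavefrontsize64)
#else
    return 32;                        // NVIDIA default
#endif
}

// Inside a kernel it would then replace the WARP_SIZE macro:
//   constexpr int warp_size = ggml_cuda_get_physical_warp_size();
```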

@IMbackK (Collaborator, Author) commented Feb 2, 2025

@JohannesGaessler done

Co-authored-by: Johannes Gäßler <[email protected]>
@IMbackK IMbackK merged commit 396856b into ggml-org:master Feb 2, 2025
46 checks passed
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Feb 4, 2025
CUDA/HIP: add support for selectable warp size to mmv

Author : Uvos
@BodhiHu (Contributor) commented Feb 5, 2025

> GGML_TRUE_WARP_SIZE

Hi @IMbackK, FYI, the warp size should be 128 for the MUSA SUDI and QY architectures:

https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Chapter09

@IMbackK (Collaborator, Author) commented Feb 5, 2025

We can adjust ggml_cuda_get_physical_warp_size to return 128 on MUSA, but someone will have to test this regularly when changes are made to expand its use, as I of course lack the hardware to do so.
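
For example, the helper could gain a MUSA branch along these lines (untested sketch; GGML_USE_MUSA is assumed to be the backend's existing compile-time guard):

```cpp
// Untested sketch: extend the helper with a MUSA branch returning 128,
// per the MUSA programming guide linked above.
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_MUSA)
    return 128; // MUSA SUDI/QY
#elif defined(__HIP_DEVICE_COMPILE__) && defined(__AMDGCN_WAVEFRONT_SIZE__)
    return __AMDGCN_WAVEFRONT_SIZE__;
#else
    return 32;
#endif
}
```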

tinglou pushed a commit to tinglou/llama.cpp that referenced this pull request Feb 13, 2025
CUDA/HIP: add support for selectable warp size to mmv
orca-zhang pushed a commit to orca-zhang/llama.cpp that referenced this pull request Feb 26, 2025
CUDA/HIP: add support for selectable warp size to mmv
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Feb 26, 2025
CUDA/HIP: add support for selectable warp size to mmv
mglambda pushed a commit to mglambda/llama.cpp that referenced this pull request Mar 8, 2025
CUDA/HIP: add support for selectable warp size to mmv