Contributor

@createthis createthis commented Aug 13, 2025

Disclaimer: I couldn't code my way out of a wet paper bag in C++. This is 100% vibe coded AI slop. Upstream issue is #15049

What?

This PR adds GGML_CUDA_ALLOW_LARGE_TENSORS. When enabled, it allows 64 bit sizes in the CUDA copy routines.

Q. What is the difference between INT_MAX and SIZE_MAX / 4? How much larger a tensor will this accommodate?

A. The difference between INT_MAX and SIZE_MAX/4 is enormous:

INT_MAX: 2,147,483,647 bytes ≈ 2.00 GiB
SIZE_MAX/4: 4,611,686,018,427,387,903 bytes ≈ 4,294,967,296 GiB ≈ 4 EiB (about 4.6 EB)
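
For anyone who wants to check the arithmetic, here is a minimal sketch. It assumes a platform with a 32-bit int and a 64-bit size_t (both enforced with static_assert), and the helper names are made up for illustration; none of this is code from the PR.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>

// Assumptions of this sketch: 32-bit int, 64-bit size_t.
static_assert(sizeof(int) == 4, "sketch assumes a 32-bit int");
static_assert(sizeof(size_t) == 8, "sketch assumes a 64-bit size_t");

// Bytes per GiB and per EiB (binary units).
constexpr uint64_t kGiB = 1ULL << 30;
constexpr uint64_t kEiB = 1ULL << 60;

// The two limits compared above, expressed as 64-bit byte counts.
constexpr uint64_t int_max_bytes() { return static_cast<uint64_t>(INT_MAX); }
constexpr uint64_t size_max_div4() { return static_cast<uint64_t>(SIZE_MAX) / 4; }

// Whole GiB / EiB that fit into `bytes` (integer division rounds down).
constexpr uint64_t whole_gib(uint64_t bytes) { return bytes / kGiB; }
constexpr uint64_t whole_eib(uint64_t bytes) { return bytes / kEiB; }
```

Both limits sit one byte below a power of two (2^31 and 2^62), so the whole-unit helpers round down to 1 GiB and 3 EiB respectively; the "≈" figures above are the rounded-to-nearest values.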

How?

cmake -B build -DGGML_CUDA=ON -DGGML_CUDA_FA_ALL_QUANTS=ON -DGGML_CUDA_ALLOW_LARGE_TENSORS=ON
cmake --build build --config Release

Then:

./build/bin/llama-server \
    --model /data/Qwen3-Coder-480B-A35B-Instruct-1M-GGUF/UD-Q4_K_XL/Qwen3-Coder-480B-A35B-Instruct-1M-UD-Q4_K_XL-00001-of-00006.gguf \
    --alias Qwen3-Coder-480B-A35B-Instruct-GGUF:UD-Q4_K_XL \
    --no-webui \
    --numa numactl \
    --threads 32 \
    --ctx-size 400000 \
    --n-gpu-layers 63 \
    -ot "blk\.(3|4|5|6|7|8|9|10|11|12|13)\.ffn_.*=CUDA0" \
    -ot exps=CPU \
    -ub 4096 -b 4096 \
    --cache-type-k q4_1 \
    --cache-type-v q4_1 \
    --seed 3407 \
    --prio 3 \
    --temp 0.7 \
    --top-p 0.8 \
    --top-k 20 \
    --repeat-penalty 1.05 \
    --min-p 0.0 \
    --log-colors on \
    --flash-attn on \
    --host 0.0.0.0 \
    --jinja \
    --port 11434

Why?

Cards with a lot of VRAM, like the RTX PRO 6000 Blackwell, have room for in-GPU context lengths whose tensors are larger than the INT_MAX size limit in the CUDA copy routines allows.
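
To make the limit concrete, here is a rough sketch of the kind of check involved. It assumes the limit is a comparison of a tensor's byte size against INT_MAX; the function names and the 8192-byte row size are hypothetical and not taken from the model or from ggml.

```cpp
#include <climits>
#include <cstdint>

// True when a tensor of `n_bytes` bytes no longer fits a signed 32-bit byte count.
constexpr bool exceeds_int_max(int64_t n_bytes) { return n_bytes > INT_MAX; }

// Byte size of a hypothetical cache tensor: n_ctx rows of row_bytes bytes each.
// Computed in 64 bits so the product itself cannot wrap for realistic sizes.
constexpr int64_t cache_bytes(int64_t n_ctx, int64_t row_bytes) { return n_ctx * row_bytes; }
```

With the made-up 8192-byte row, a 100,000-token context stays under the limit (819,200,000 bytes) while a 400,000-token context does not (3,276,800,000 bytes).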

Results

[Screenshot, 2025-08-13 at 2:16:01 PM]

This model starts out at 20-22 tok/s generation at 0 context, so the result in the screenshot is pretty terrible performance by comparison. Still, when you absolutely, positively MUST read a huge number of tokens, this may be a solution.

@github-actions github-actions bot added the "Nvidia GPU" and "ggml" labels Aug 13, 2025
Collaborator

@JohannesGaessler JohannesGaessler left a comment

For the copy operations, register pressure is not an issue. It should be fine to just use 64-bit integers for everything, without the need for an extra compile option.

@bitbottrap

This works for me. I'm not familiar with CUDA but from the comments it sounds like the #ifdef fencing isn't required?

@createthis
Contributor Author

> This works for me. I'm not familiar with CUDA but from the comments it sounds like the #ifdef fencing isn't required?

I just saw @JohannesGaessler's comment a couple days ago. I'm currently focused on trying to get another PR pushed through, but I'll circle back around to implement the suggested change shortly.

@createthis
Contributor Author

@JohannesGaessler I removed the compile option. I also ran this with LongBench for a few hours (15/502 tests) just to ensure it was working: https://github.com/createthis/LongBench/pull/1/files

LongBench tested it out to 400k context:
[Screenshot, 2025-09-09 at 8:57:03 PM]

Collaborator

@JohannesGaessler JohannesGaessler left a comment

I think you misunderstood what I meant: you should be using 64 bit values for the kernel arguments and add a loop that allows the kernel to iterate over an essentially arbitrarily large amount of data. Launching multiple CUDA kernels in chunks is a fundamentally bad solution.
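
One common way to get "a loop inside the kernel" is a grid-stride loop with 64-bit indices. The sketch below is a CPU emulation of that pattern in plain C++, not the PR's code and not necessarily the exact shape the reviewer has in mind: cpy_thread and cpy_launch are hypothetical names, and the index parameters stand in for CUDA's blockIdx.x, blockDim.x, threadIdx.x and gridDim.x.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Work done by one simulated thread. Every index is 64-bit, so `ne` may exceed
// INT_MAX. The thread starts at its global index and advances by the total
// number of threads in the grid until it runs past the end of the data.
void cpy_thread(const float * src, float * dst, int64_t ne,
                int64_t block_idx, int64_t block_dim, int64_t thread_idx, int64_t grid_dim) {
    const int64_t start  = block_idx * block_dim + thread_idx;
    const int64_t stride = grid_dim * block_dim;
    for (int64_t i = start; i < ne; i += stride) {
        dst[i] = src[i];
    }
}

// Host-side stand-in for ONE kernel launch: runs every simulated thread once.
// A fixed-size grid covers any `ne`, so the copy never has to be split into
// several launches.
void cpy_launch(const float * src, float * dst, int64_t ne, int64_t grid_dim, int64_t block_dim) {
    assert(grid_dim > 0 && block_dim > 0);
    for (int64_t b = 0; b < grid_dim; ++b) {
        for (int64_t t = 0; t < block_dim; ++t) {
            cpy_thread(src, dst, ne, b, block_dim, t, grid_dim);
        }
    }
}
```

For example, cpy_launch(src, dst, 1000, 3, 8) copies all 1000 elements with only 24 simulated threads, because each thread handles every 24th element. In a real kernel the same loop lets a capped grid size handle a tensor of any length in a single launch.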

(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
const char * cx, char * cdst, const int64_t ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {

Suggested change
- const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+ const int64_t nb12, const int64_t nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
