An efficient, easy-to-use platform for inference and serving of local LLMs, including an OpenAI-compatible API server.
- OpenAI compatible API server provided for serving LLMs.
- Highly extensible trait-based system to allow rapid implementation of new module pipelines.
- Streaming support in generation.
- Efficient management of key-value cache with PagedAttention.
- Continuous batching (batched decoding for incoming requests over time).
- In-situ quantization (and in-situ Marlin format conversion)
- GPTQ/Marlin format quantization (4-bit)
- Support Mac/Metal devices
- Support Multi-GPU inference (both multi-process and multi-threaded modes)
- Support Multi-node inference with MPI runner
- Support Chunked Prefilling (default chunk size 8K)
- Support CUDA Graph
- Support Model Context Protocol (MCP) and OpenAI-compatible tool calling
- Support Prefix Caching
- Support Block-wise FP8 Models (SM90+, Qwen3 Series)
- Support Flashinfer Backend
Currently, candle-vllm supports chat serving for the following model structures.
| Model ID | Model Type | Decoding Speed / Request (BF16, Hopper) | Quantized (Q4K or Marlin) |
|---|---|---|---|
| #1 | LLAMA | 105 tks/s (8B) | 154 tks/s (8B, Q4k), 163 tks/s (8B, Marlin) |
| #2 | Mistral | 112 tks/s (7B) | 171 tks/s (7B, Q4k), 175 tks/s (7B, Marlin) |
| #3 | Phi3/Phi4 | 139 tks/s (3.8B) | 180 tks/s (3.8B, Q4k) |
| #4 | QWen2/Qwen3 Dense | 96 tks/s (8B) | 135 tks/s (8B, Q4k) |
| #5 | QWen3 MoE | 92 tks/s (30B) | 114 tks/s (30B, Q4K) |
| #6 | QWen3-Next MoE | 71 tks/s (80B, BF16, tp=2) | TBD |
| #7 | QWen3.5 Dense | 30 tks/s (27B, BF16) | ~42 tks/s (27B, Q4K / FP8) |
| #8 | QWen3.5 MoE | 82 tks/s (35B) | 93 tks/s (35B, Q4K) |
| #9 | Yi | 148 tks/s (6B) | 180 tks/s (6B, Q4k) |
| #10 | StableLM | 223 tks/s (3B) | - |
| #11 | Gemma-2/Gemma-3 | 92 tks/s (9B) | 115 tks/s (9B, Marlin) |
| #12 | DeepSeek V2/V3/R1 | TBD | ~20 tks (AWQ 671B, tp=8, offloading) |
| #13 | QwQ-32B | 45 tks/s (32B, tp=2) | 63 tks/s (32B, Q4K) |
| #14 | GLM4 | 89 tks/s (9B) | 124 tks/s (9B, Q4K) |
Nvidia GPU and Apple Silicon
Clone code
git clone git@github.com:EricLBuehler/candle-vllm.git
cd candle-vllm
CUDA (CUDA 11+, 12+, 13.0)
Option 1 (Install into docker)
# The host driver version must be >= the specified CUDA version; `flashattn` and `flashinfer` take longer to build
# Change the `sm_XX` argument to your hardware spec, e.g., sm_70 (V100), sm_75 (T4, RTX20xx), sm_80 (Ampere, A100), sm_86/89 (RTX30xx, RTX40xx), sm_90 (Hopper, H100/H200), sm_100/sm_120 (Blackwell, RTX50xx).
./build_docker.sh "cuda,nccl,graph,flashinfer,cutlass" sm_90 13.0.0
# Or switch to the Flash attention backend, or use the China mirror for Rust crates (for users in Mainland China)
./build_docker.sh "cuda,nccl,graph,flashattn,cutlass" sm_80 12.9.0 1
Option 2 (Manual Installation)
Install dependencies
sudo apt update
sudo apt install git libssl-dev pkg-config curl -y
# Install CUDA toolkit (optional)
sudo apt install -y cuda-toolkit-12-9 # must be <= the host driver version
# Install Rust (1.83.0+ required)
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# Make sure the CUDA Toolkit can be found in the system PATH
export PATH=$PATH:/usr/local/cuda/bin/
Install for single node inference
# Remove "flashattn,flashinfer,cutlass" for sm_75 and sm_70
# Replace `flashinfer` with `flashattn` to use Flash attention backend
cargo install --features cuda,nccl,graph,flashinfer,cutlass --path .
Install for multinode inference
# Use MPI (multiple GPUs across multiple machines)
sudo apt install libopenmpi-dev openmpi-bin -y #install mpi
sudo apt install clang libclang-dev
cargo install --features cuda,nccl,graph,flashattn,cutlass,mpi --path .
# FlashInfer backend
cargo install --features cuda,nccl,graph,flashinfer,cutlass,mpi --path .
Mac/Metal (single-node only)
Install Xcode command line tools
Install with metal feature
cargo install --features metal --path .
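The install commands above differ only in their `--features` list. As a minimal sketch, the selection rules stated in this section can be encoded in Python: Metal builds use only `metal`; CUDA builds use `cuda,nccl,graph` plus an attention backend and `cutlass`, which are dropped on sm_70/sm_75; `mpi` is appended for multi-node builds. The helper `cargo_features` is hypothetical and not part of candle-vllm.

```python
def cargo_features(platform: str, sm: int = 90, backend: str = "flashinfer",
                   mpi: bool = False) -> str:
    """Compose the --features string for `cargo install` (hypothetical helper).

    Rules taken from this README:
    - Mac/Metal builds use only `metal` and are single-node only.
    - CUDA builds use cuda, nccl and graph; the attention backend
      (flashinfer or flashattn) and cutlass are omitted on sm_70 and sm_75.
    - `mpi` is appended for multi-node inference.
    """
    if platform == "metal":
        if mpi:
            raise ValueError("Mac/Metal builds are single-node only")
        return "metal"
    if platform != "cuda":
        raise ValueError(f"unknown platform: {platform}")
    if backend not in ("flashinfer", "flashattn"):
        raise ValueError(f"unknown attention backend: {backend}")
    features = ["cuda", "nccl", "graph"]
    if sm not in (70, 75):
        features += [backend, "cutlass"]
    if mpi:
        features.append("mpi")
    return ",".join(features)


if __name__ == "__main__":
    print(f"cargo install --features {cargo_features('cuda', sm=90)} --path .")
```

For example, `cargo_features("cuda", sm=80, backend="flashattn", mpi=True)` reproduces the feature list of the MPI build command above.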
[ENV_PARAM] cargo run [BUILD_PARAM] -- [PROGRAM_PARAM] [MODEL_ID/MODEL_WEIGHT_PATH] [CACHE CONFIG] [WEB UI]
Example:
[RUST_LOG=warn] cargo run [--release --features cuda,nccl,flashinfer,cutlass,graph] -- [--log --dtype bf16 --p 2000 --d 0,1 --gpu-memory-fraction 0.7 --isq q4k --prefill-chunk-size 8192 --frequency-penalty 1.1 --presence-penalty 1.1 --enforce-parser qwen_coder] [--m Qwen/Qwen3.5-27B-FP8] [--fp8-kvcache] [--ui-server]
- ENV_PARAM: RUST_LOG=warn
- BUILD_PARAM: --release --features cuda,nccl,flashinfer,cutlass,graph
- PROGRAM_PARAM: --log --dtype bf16 --p 2000 --d 0,1 --gpu-memory-fraction 0.7 --isq q4k --prefill-chunk-size 8192 --frequency-penalty 1.1 --presence-penalty 1.1 --enforce-parser qwen_coder
- MODEL_ID/MODEL_WEIGHT_PATH: --m Qwen/Qwen3.5-27B-FP8 (or --w to specify a local model path)
- CACHE CONFIG: --fp8-kvcache
- WEB UI: --ui-server

where,
- --p: server port;
- --d: device ids;
- --w: weight path (safetensors folder);
- --f: weight file (for gguf);
- --m: huggingface model-id;
- --isq q4k: convert weights into q4k format during model loading;
- --prefill-chunk-size: split the prefill into chunks of this size (default 8K, 0 to disable);
- --frequency-penalty and --presence-penalty: repetition penalties (value from -2.0 to 2.0);
- --mem (kvcache-mem-gpu): sets a fixed KV cache budget in MB;
- --gpu-memory-fraction: auto-sizes the KV cache after model load using fraction * remaining_gpu_memory;
- --enforce-parser: forces a specific tool parser backend such as qwen_coder, qwen, json, or mistral;
- --fp8-kvcache: enables the FP8 KV cache;
- --prefix-cache: enables prefix cache reuse;
- --prefix-cache-max-tokens: caps the prefix cache size;
- --ui-server: starts a built-in ChatGPT-like Web UI server.

Replace flashinfer in BUILD_PARAM with flashattn to use the Flash attention backend instead.
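Two of these flags boil down to simple arithmetic, sketched below under stated assumptions. `prefill_chunks` splits a prompt into `--prefill-chunk-size` pieces (assuming the 8K default means 8192 tokens, as in the example command; 0 disables chunking). `kv_budget_mb` applies the `--gpu-memory-fraction` rule described in this README (fraction of the free memory left after model load on each device, minimum across ranks), and `kv_tokens` converts a budget into an approximate token capacity. All three names are hypothetical; `kv_tokens` assumes standard multi-head or grouped-query attention (one K and one V vector per layer and KV head), treats MB as 2^20 bytes, and ignores PagedAttention block granularity, so it does not apply to models such as DeepSeek V2/V3, which use MLA.

```python
def prefill_chunks(n_tokens: int, chunk_size: int = 8192) -> list[tuple[int, int]]:
    """Split a prompt into [start, end) token ranges, one per prefill step.

    chunk_size == 0 disables chunking (the whole prompt is a single step).
    """
    if n_tokens < 0 or chunk_size < 0:
        raise ValueError("n_tokens and chunk_size must be non-negative")
    if chunk_size == 0:
        return [(0, n_tokens)] if n_tokens else []
    return [(s, min(s + chunk_size, n_tokens)) for s in range(0, n_tokens, chunk_size)]


def kv_budget_mb(free_mb_per_rank: list[float], fraction: float = 0.7) -> float:
    """fraction * free memory after model load on each device; the minimum
    across ranks becomes the per-rank KV cache budget."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return min(fraction * free for free in free_mb_per_rank)


def kv_tokens(budget_mb: float, layers: int, kv_heads: int, head_dim: int,
              bytes_per_elem: int = 2) -> int:
    """Approximate number of tokens whose K and V vectors fit in the budget.

    bytes_per_elem is 2 for a BF16/F16 KV cache and 1 for an FP8 KV cache.
    """
    per_token = 2 * layers * kv_heads * head_dim * bytes_per_elem
    return int(budget_mb * 2**20 // per_token)


if __name__ == "__main__":
    print(prefill_chunks(20_000))            # three chunks with the 8K default
    budget = kv_budget_mb([30_000, 28_000])  # hypothetical free MB on two GPUs
    # Hypothetical model shape: 36 layers, 8 KV heads, head_dim 128.
    print(budget, kv_tokens(budget, 36, 8, 128), kv_tokens(budget, 36, 8, 128, 1))
```

Halving `bytes_per_elem` doubles the token capacity, which is the intuition behind `--fp8-kvcache`.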
- Note: for the Docker build, run the following command to enter the container:
docker run --rm -it --gpus all --network host -v /home:/home -v /data:/data candle-vllm:latest bash
Run Uncompressed models
Local Path (with port, device)
candle-vllm --p 8000 --d 0,1 --w /home/Qwen3-30B-A3B-Instruct-2507/ --prefix-cache
Local Path (ISQ, +UI Server)
candle-vllm --p 8000 --d 0,1 --w /home/Qwen3.5-27B/ --isq q4k --ui-server --prefix-cache
Model-ID (download from Huggingface)
candle-vllm --m Qwen/Qwen3.5-35B-A3B --ui-server --prefix-cache
FP8 Model (block-wise quant, build with the cutlass feature)
candle-vllm --m Qwen/Qwen3.5-27B-FP8 --ui-server --prefix-cache
# MacOS/Metal (Dense)
candle-vllm --m Qwen/Qwen3-4B-Instruct-2507-FP8 --ui-server --prefix-cache
Run GGUF models
Local Path
candle-vllm --f /home/data/Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --ui-server
Model-ID (download from Huggingface)
candle-vllm --m unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF --f Qwen3-30B-A3B-Instruct-2507-Q4_K_M.gguf --ui-server
Run GGUF models on Apple Silicon
Local Path (assume model downloaded in /home)
candle-vllm --f /home/qwq-32b-q4_k_m.gguf --ui-server
Model-ID (download from Huggingface)
candle-vllm --m Qwen/QwQ-32B-GGUF --f qwq-32b-q4_k_m.gguf --ui-server
Run any uncompressed model as quantized with in-situ quantization
Simply add the --isq parameter when running unquantized models:
candle-vllm --p 2000 --m Qwen/Qwen3.5-27B --isq q4k
Options for the in-situ --isq parameter: ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"]
Run Marlin-compatible GPTQ models (4-bit GPTQ, 128-group, desc_act=False)
Local Path
candle-vllm --w /home/DeepSeek-R1-Distill-Qwen-14B-GPTQ_4bit-128g
Model-ID (download from Huggingface)
candle-vllm --m thesven/Llama-3-8B-GPTQ-4bit
Convert any uncompressed model to Marlin-compatible format
python3 examples/convert_marlin.py --src /home/DeepSeek-R1-Distill-Qwen-14B/ --dst /home/DeepSeek-R1-Distill-Qwen-14B-GPTQ_4bit-128g
candle-vllm --w /home/DeepSeek-R1-Distill-Qwen-14B-GPTQ_4bit-128g
Run Marlin-compatible AWQ models
Convert AWQ model to Marlin-compatible format
python3 examples/convert_awq_marlin.py --src /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4/ --dst /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/ --bits 4 --method awq --group 128 --nk False
Run the converted AWQ model
candle-vllm --d 0 --w /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/
Run Marlin-format models
candle-vllm --w /home/DeepSeek-R1-Distill-Qwen-14B-GPTQ-Marlin/
Run Large models using multi-process mode (Multi-GPU)
QwQ-32B BF16 model on two GPUs
candle-vllm --d 0,1 --w /home/QwQ-32B/
QwQ-32B 4-bit AWQ model on two GPUs
- Convert AWQ model to Marlin-compatible format
python3 examples/convert_awq_marlin.py --src /home/QwQ-32B-AWQ/ --dst /home/QwQ-32B-AWQ-Marlin/ --bits 4 --method awq --group 128 --nk False
- Run the converted AWQ model
candle-vllm --d 0,1 --w /home/QwQ-32B-AWQ-Marlin/
Note: the number of GPUs used (--d) must be a power of two (e.g., 2, 4, or 8).
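Since multi-process mode expects a power-of-two GPU count, a launcher script might check the `--d` value before starting the server. A minimal sketch; `parse_device_ids` is a hypothetical helper, not candle-vllm's own validation logic.

```python
def parse_device_ids(arg: str) -> list[int]:
    """Parse a --d value such as "0,1,2,3" and check the GPU count (hypothetical helper)."""
    ids = [int(part) for part in arg.split(",") if part.strip()]
    if not ids:
        raise ValueError("at least one device id is required")
    if len(set(ids)) != len(ids):
        raise ValueError(f"duplicate device ids in {arg!r}")
    n = len(ids)
    # A positive integer is a power of two exactly when it has a single bit set.
    if n & (n - 1) != 0:
        raise ValueError(f"{n} GPUs requested; use a power of two such as 2, 4, or 8")
    return ids


if __name__ == "__main__":
    print(parse_device_ids("0,1,2,3"))
```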
Run Large models using multi-threaded mode (Multi-GPU, for debugging purposes)
Simply add the --multithread parameter.
QwQ-32B BF16 model on two GPUs
candle-vllm --multithread --d 0,1 --w /home/QwQ-32B/
If you encounter problems in multi-threaded multi-GPU mode, you may:
export NCCL_P2P_DISABLE=1 # disable P2P, since it can cause illegal memory access in certain environments
Run DeepSeek-R1 (671B/685B) on Lower GPU Memories (CPU offloading)
1. Convert DeepSeek-R1-AWQ model to Marlin-compatible format
python3 examples/convert_awq_marlin.py --src /data/DeepSeek-R1-AWQ/ --dst /data/DeepSeek-R1-AWQ-Marlin/
2. Run DeepSeek-R1 model on 8 x A100(40GB)
candle-vllm --log --d 0,1,2,3,4,5,6,7 --w /data/DeepSeek-R1-AWQ-Marlin/ --num-experts-offload-per-rank 15
Note: This setup offloads 15 experts per rank (a total of 120 out of 256 experts) to the CPU (around 150GB of additional host memory required). During inference, these offloaded experts are swapped back into GPU memory as needed. If you have even less GPU memory, consider increasing the --num-experts-offload-per-rank parameter (up to a maximum of 32 experts per rank in this case).
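The figures in this note follow from simple arithmetic, assuming the 256 routed experts are split evenly across the 8 ranks (which is why 32 per rank is the ceiling). A sketch with a hypothetical `offload_plan` helper; it deliberately does not estimate host memory, because that depends on the per-expert weight size.

```python
def offload_plan(total_experts: int, ranks: int, offload_per_rank: int) -> dict:
    """Summarize --num-experts-offload-per-rank (hypothetical helper).

    Assumes experts are partitioned evenly across ranks.
    """
    if total_experts % ranks != 0:
        raise ValueError("experts must divide evenly across ranks")
    experts_per_rank = total_experts // ranks
    if not 0 <= offload_per_rank <= experts_per_rank:
        raise ValueError(f"offload_per_rank must be in [0, {experts_per_rank}]")
    offloaded = offload_per_rank * ranks
    return {
        "experts_per_rank": experts_per_rank,       # upper bound for the flag
        "offloaded_total": offloaded,                # experts kept in host memory
        "resident_on_gpu": total_experts - offloaded,
    }


if __name__ == "__main__":
    print(offload_plan(256, 8, 15))
```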
Run DeepSeek-R1 (671B/685B) on Multi-node
1. Install MPI and build with MPI feature
sudo apt update
sudo apt install libopenmpi-dev openmpi-bin -y #install mpi
sudo apt install clang libclang-dev
# clone the repo into the same directory on both nodes, then build
cargo install --features cuda,nccl,mpi --path . #build with mpi feature
2. Convert AWQ deepseek to Marlin-compatible format
python3 examples/convert_awq_marlin.py --src /data/DeepSeek-R1-AWQ/ --dst /data/DeepSeek-R1-AWQ-Marlin/
3. Configure the multi-node environment
The MPI runner requires identical hardware and software configurations on all nodes; please ensure that the weights and candle-vllm binaries are located in identical folders on the different nodes. The nodes need passwordless SSH access to each other (port 22 in this case; root user if --allow-run-as-root is used). %NET_INTERFACE% is the active network interface, obtained through the command 'ifconfig -a'. You may disable InfiniBand if it's not available on the nodes by inserting the env "-x NCCL_IB_DISABLE=1". The hostfile can be defined as:
Example (two nodes, each with 8 GPUs)
192.168.1.100 slots=8
192.168.1.101 slots=8
4. Run the model on two nodes with the MPI runner
sudo mpirun -np 16 -x RUST_LOG=info -hostfile ./hostfile --allow-run-as-root -bind-to none -map-by slot --mca plm_rsh_args "-p 22" --mca btl_tcp_if_include %NET_INTERFACE% candle-vllm --log --d 0,1,2,3,4,5,6,7 --w /data/DeepSeek-R1-AWQ-Marlin/
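In the two-node example, `-np 16` equals the total number of hostfile slots (2 nodes × 8 GPUs, one rank per GPU). A small sketch that generates the hostfile text and the matching `-np` value from a node list; `hostfile_and_np` is a hypothetical helper, and the one-rank-per-GPU assumption is taken from the example above.

```python
def hostfile_and_np(nodes: dict[str, int]) -> tuple[str, int]:
    """Return Open MPI hostfile text and the matching `mpirun -np` value.

    Each node contributes one `<host> slots=<gpus>` line; with one rank
    per GPU, -np is the total slot count.
    """
    if any(slots <= 0 for slots in nodes.values()):
        raise ValueError("every node needs at least one slot")
    text = "".join(f"{host} slots={slots}\n" for host, slots in nodes.items())
    return text, sum(nodes.values())


if __name__ == "__main__":
    text, np_ranks = hostfile_and_np({"192.168.1.100": 8, "192.168.1.101": 8})
    print(text, end="")
    print(f"-np {np_ranks}")
```

Write the text to `./hostfile` on the launching node; the per-node `--d` list should still name that node's local GPUs.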
Run with NUMA binding
Prerequisite: ensure your machine has more than one NUMA node (typically, more than one physical CPU), and install numactl:
sudo apt-get install numactl
Suppose your machine has 8 GPUs and 2 NUMA nodes, with each set of 4 GPUs bound to a different NUMA node. To achieve optimal performance during inference using all GPUs, use the following NUMA binding:
MAP_NUMA_NODE=0,0,0,0,1,1,1,1 numactl --cpunodebind=0 --membind=0 candle-vllm --d 0,1,2,3,4,5,6,7 --w /home/data/DeepSeek-V2-Chat-AWQ-Marlin
To use only 4 GPUs, you can apply this NUMA binding:
MAP_NUMA_NODE=0,0,0,0 numactl --cpunodebind=0 --membind=0 candle-vllm --d 0,1,2,3 --w /home/data/DeepSeek-V2-Chat-AWQ-Marlin
where numactl --cpunodebind=0 --membind=0 above sets the NUMA binding for the master rank (master process), which should match MAP_NUMA_NODE.
Note: The exact NUMA binding sequence may vary depending on your hardware configuration.
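The binding strings above can be derived from a GPU-to-NUMA-node map (on NVIDIA systems, the NUMA Affinity column of `nvidia-smi topo -m` is one place to read it). A minimal sketch; `numa_launch_prefix` is hypothetical, and it assumes the master process should be bound to the NUMA node of the first GPU in `--d`, which matches both examples but is an assumption rather than documented behaviour.

```python
def numa_launch_prefix(gpu_to_numa: dict[int, int], devices: list[int]) -> str:
    """Build the `MAP_NUMA_NODE=... numactl ...` prefix for the GPUs in --d.

    MAP_NUMA_NODE lists one NUMA node per GPU, in --d order. The master
    process is bound to the NUMA node of the first listed GPU (assumption).
    """
    if not devices:
        raise ValueError("at least one device is required")
    missing = [d for d in devices if d not in gpu_to_numa]
    if missing:
        raise KeyError(f"no NUMA node known for GPU(s) {missing}")
    mapping = ",".join(str(gpu_to_numa[d]) for d in devices)
    master = gpu_to_numa[devices[0]]
    return f"MAP_NUMA_NODE={mapping} numactl --cpunodebind={master} --membind={master}"


if __name__ == "__main__":
    # Topology from the example: GPUs 0-3 on NUMA node 0, GPUs 4-7 on node 1.
    topo = {gpu: 0 if gpu < 4 else 1 for gpu in range(8)}
    print(numa_launch_prefix(topo, list(range(8))))
```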
Run chat frontend after starting the backend service
Chat frontend (any frontend compatible with the OpenAI API works; some simple options are available below):
Option 1: Chat with Chat.py (for simple tests)
Install API and chatbot dependencies (openai package is only used for local chat with candle-vllm)
python3 -m pip install openai rich click
Chat with the mini chatbot (plain text)
python3 examples/chat.py
Pass generation parameters (for reasoning models, add --thinking True)
python3 examples/chat.py --temperature 0.7 --top_k 64 --top_p 0.9 --thinking True --system_prompt "Thinking big!"
Chat with the mini chatbot (live update with Markdown, may cause flicker)
python3 examples/chat.py --live
Option 2: Chat with a naive ChatUI (or the popular Dify frontend)
Install naive ChatUI and its dependencies:
git clone git@github.com:guoqingbao/candle-vllm-demo.git
cd candle-vllm-demo
apt install npm #install npm if needed
npm install n -g #update node js if needed
n stable #update node js if needed
npm i -g pnpm #install pnpm manager
pnpm install #install ChatUI dependencies
Launching the ChatUI:
pnpm run dev # run the ChatUI
Troubleshooting the Node.js error
ENOSPC: System limit for number of file watchers reached
echo fs.inotify.max_user_watches=524288 | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
Option 3: Chat completion request with HTTP post
curl -X POST "http://127.0.0.1:2000/v1/chat/completions" \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer YOUR_API_KEY" \
  -d '{
    "model": "llama7b",
    "messages": [
      {"role": "user", "content": "Explain how to best learn Rust."}
    ],
    "temperature": 0.7,
    "max_tokens": 128,
    "stop": {"Single":"</s>"}
  }'
Sample response:
{"id":"cmpl-53092967-c9cf-40e0-ae26-d7ac786d59e8","choices":[{"message":{"content":" Learning any programming language requires a combination of theory, practice, and dedication. Here are some steps and resources to help you learn Rust effectively:\n\n1. Start with the basics:\n\t* Understand the syntax and basic structure of Rust programs.\n\t* Learn about variables, data types, loops, and control structures.\n\t* Familiarize yourself with Rust's ownership system and borrowing mechanism.\n2. Read the Rust book:\n\t* The Rust book is an official resource that provides a comprehensive introduction to the language.\n\t* It covers topics such","role":"[INST]"},"finish_reason":"length","index":0,"logprobs":null}],"created":1718784498,"model":"llama7b","object":"chat.completion","usage":{"completion_tokens":129,"prompt_tokens":29,"total_tokens":158}}
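The same request can be sent from Python with only the standard library, which is handy in environments where the openai package is not installed. A minimal sketch: `build_chat_request` and `parse_chat_response` are hypothetical helpers, the base URL assumes a server started with `--p 2000`, and the candle-vllm-specific `stop` form used in the curl example is omitted. The parser reads the `choices[0].message.content`, `finish_reason`, and `usage` fields shown in the sample response.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, content: str, api_key: str = "EMPTY",
                       temperature: float = 0.7, max_tokens: int = 128) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible /chat/completions endpoint."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
        method="POST",
    )


def parse_chat_response(body: bytes) -> tuple[str, str, dict]:
    """Extract (content, finish_reason, usage) from a non-streaming response."""
    data = json.loads(body)
    choice = data["choices"][0]
    return choice["message"]["content"], choice["finish_reason"], data["usage"]


if __name__ == "__main__":
    req = build_chat_request("http://127.0.0.1:2000/v1", "llama7b", "Explain how to best learn Rust.")
    with urllib.request.urlopen(req) as resp:
        text, finish_reason, usage = parse_chat_response(resp.read())
    print(text)
    print(finish_reason, usage)
```

A `finish_reason` of `"length"`, as in the sample above, means generation stopped at `max_tokens`.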
Option 4: Chat completion with the openai package
In your terminal, install the openai Python package by running pip install openai. I use version 1.3.5. Then, create a new Python file and write the following code:
import openai

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:2000/v1/"

completion = openai.chat.completions.create(
    model="llama",
    messages=[
        {
            "role": "user",
            "content": "Explain how to best learn Rust.",
        },
    ],
    max_tokens=64,
)
print(completion.choices[0].message.content)
After the candle-vllm service is running, run the Python script and enjoy efficient inference with an OpenAI-compatible API server!
Batched requests
Install the openai API first
python3 -m pip install openai
Run the benchmark test
python3 examples/benchmark.py --batch 16 --max_tokens 1024
Refer to examples/benchmark.py
import asyncio
from typing import List, Tuple

from openai import Stream
from openai.types.chat import ChatCompletionChunk

# chat_completion() and stream_response() are defined in examples/benchmark.py


async def benchmark():
    model = "mistral7b"
    max_tokens = 1024
    # 16 requests (8 prompts, each sent twice)
    prompts = [
        "Explain how to best learn Rust.",
        "Please talk about deep learning in 100 words.",
        "Do you know the capital city of China? Talk the details of you known.",
        "Who is the best female actor in the world? Explain why.",
        "How to dealing with depression?",
        "How to make money in short time?",
        "What is the future trend of large language model?",
        "The famous tech companies in the world.",
    ] * 2

    # send 16 chat requests at the same time
    tasks: List[asyncio.Task] = []
    for prompt in prompts:
        tasks.append(asyncio.create_task(chat_completion(model, max_tokens, prompt)))

    # obtain the corresponding stream object for each request
    outputs: List[Stream[ChatCompletionChunk]] = await asyncio.gather(*tasks)

    # tasks for streaming chat responses
    tasks_stream: List[asyncio.Task] = []
    for i, output in enumerate(outputs):
        tasks_stream.append(asyncio.create_task(stream_response(i, output)))

    # gather the response texts
    results: List[Tuple[int, str]] = await asyncio.gather(*tasks_stream)

    # print the results; chat completion statistics are shown by the backend server (i.e., candle-vllm)
    for idx, output in results:
        print("\n\n Response {}: \n\n {}".format(idx, output))


asyncio.run(benchmark())
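examples/benchmark.py uses asyncio with the openai package. The same idea can be sketched with only the standard library: send all prompts concurrently so the server can batch their decoding, then divide the total completion tokens by wall-clock time. `run_batch` and its `send` callback are hypothetical; `send` stands in for any function that posts one chat request and returns the reply text and its completion-token count (for example, read from the response's `usage` field).

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Tuple


def run_batch(prompts: List[str],
              send: Callable[[str], Tuple[str, int]],
              max_workers: int = 16) -> Tuple[List[Tuple[str, int]], int, float]:
    """Send prompts concurrently (hypothetical helper).

    Returns the per-prompt (reply_text, completion_tokens) results in input
    order, the total number of completion tokens, and aggregate throughput in
    tokens per second of wall-clock time (prefill and network time included).
    """
    start = time.perf_counter()
    # Threads suit this I/O-bound work: each one mostly waits on the server.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(send, prompts))
    elapsed = time.perf_counter() - start
    total_tokens = sum(tokens for _, tokens in results)
    throughput = total_tokens / elapsed if elapsed > 0 else float("inf")
    return results, total_tokens, throughput
```

Because the requests overlap in time, continuous batching lets the server decode them together, so the aggregate figure can exceed the single-request speeds listed in the model table.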
Loading unquantized models as gguf quantized or marlin format
Candle-vllm supports in-situ quantization, which transforms default weights (F32/F16/BF16) into any GGML/GGUF format, or 4-bit GPTQ/AWQ weights into Marlin format, during model loading. This feature helps conserve GPU memory and speed up inference, making it more efficient on consumer-grade GPUs (e.g., RTX 4090). To use this feature, simply supply the isq parameter when running candle-vllm.
For unquantized models:
candle-vllm --p 2000 --w /home/Meta-Llama-3.1-8B-Instruct/ --isq q4k
Options for isq parameters: ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"]
For quantized 4-bit GPTQ models:
candle-vllm --p 2000 --w /home/mistral_7b-int4/
Please note for Marlin:
- It may take a few minutes to load F32/F16/BF16 models into a quantized format;
- Marlin format in-situ conversion only supports 4-bit GPTQ (with sym=True, groupsize=128 or -1, desc_act=False) and 4-bit AWQ (after conversion using the given script; refer to Other Usage);
- Marlin format is only supported on the CUDA platform.
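To get a feel for how much memory in-situ quantization saves, the effective bits per weight of each GGML format can be derived from its block layout (bytes per block times 8, divided by weights per block; e.g., Q4_K packs 256 weights into 144 bytes). The sketch below applies those figures to a parameter count. It is a rough estimate under stated assumptions: it treats every weight as quantized, whereas some tensors may stay in higher precision, and it ignores activations and the KV cache. The names `GGML_BITS_PER_WEIGHT` and `weight_gib` are hypothetical.

```python
# Effective bits per weight implied by GGML block layouts:
# bytes per block * 8 / weights per block (32 for legacy quants, 256 for k-quants).
GGML_BITS_PER_WEIGHT = {
    "q4_0": 18 * 8 / 32,    # 4.5
    "q4_1": 20 * 8 / 32,    # 5.0
    "q5_0": 22 * 8 / 32,    # 5.5
    "q5_1": 24 * 8 / 32,    # 6.0
    "q8_0": 34 * 8 / 32,    # 8.5
    "q2k": 84 * 8 / 256,    # 2.625
    "q3k": 110 * 8 / 256,   # 3.4375
    "q4k": 144 * 8 / 256,   # 4.5
    "q5k": 176 * 8 / 256,   # 5.5
    "q6k": 210 * 8 / 256,   # 6.5625
    "bf16": 16.0,           # unquantized baseline for comparison
}


def weight_gib(num_params: float, fmt: str) -> float:
    """Approximate weight memory in GiB if every weight uses `fmt`."""
    return num_params * GGML_BITS_PER_WEIGHT[fmt] / 8 / 2**30


if __name__ == "__main__":
    for fmt in ("bf16", "q8_0", "q4k"):
        print(f"8B params as {fmt}: ~{weight_gib(8e9, fmt):.1f} GiB")
```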
KV Cache config, sampling parameter, etc.
The `--mem` (`kvcache-mem-gpu`) parameter sets a fixed KV cache budget in MB. By default this is `4096` MB.
The `--gpu-memory-fraction` parameter is a lighter-weight auto mode. When omitted, it defaults to `0.7`. After the model finishes loading, candle-vllm probes each loaded CUDA or Metal device and computes the KV cache budget as:
gpu_memory_fraction * remaining_gpu_memory_after_model_load
This means the fraction directly controls how much of the free GPU memory left after model load can be used for the combined GPU cache budget. The minimum detected budget across ranks is used as the KV cache budget per rank. For example:
candle-vllm --w /home/Qwen3-Coder-30B-A3B-Instruct-FP8 --d 0,1 --gpu-memory-fraction 0.7
Use `--mem` when you want an explicit fixed budget. Use `--gpu-memory-fraction` when you want the server to adapt to the GPU memory still available after model load.
The `--enforce-parser` parameter forces a specific tool-calling parser backend instead of the model-based default selection. This is useful when a model is compatible with a parser but is not auto-detected correctly. Common values are `qwen_coder`, `qwen`, `json`, and `mistral`. For example:
candle-vllm --w /home/Qwen3-Coder-30B-A3B-Instruct-FP8 --enforce-parser qwen_coder
Invalid parser names are rejected at startup.
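The tool parser is what turns a model's raw output into structured tool calls for the client. Assuming candle-vllm's OpenAI-compatible tool calling accepts the standard `tools` request field and returns `message.tool_calls` with JSON-encoded `arguments` strings (as in the OpenAI Chat Completions format), a client could build and read tool calls as sketched below. The `get_weather` tool and both helper names are hypothetical.

```python
import json


def tool_call_request(model: str, user_msg: str) -> dict:
    """Chat request body advertising one function tool (OpenAI `tools` format)."""
    get_weather = {  # hypothetical tool for illustration
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city.",
            "parameters": {  # JSON Schema describing the arguments
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    return {
        "model": model,
        "messages": [{"role": "user", "content": user_msg}],
        "tools": [get_weather],
    }


def extract_tool_calls(response: dict) -> list:
    """Return (name, arguments) pairs from the first choice, decoding the
    JSON-encoded `arguments` string of each tool call."""
    message = response["choices"][0]["message"]
    return [
        (call["function"]["name"], json.loads(call["function"]["arguments"]))
        for call in message.get("tool_calls") or []
    ]
```

The client would then run the named function itself and send its result back in a follow-up message, which is how the OpenAI tool-calling flow is structured.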
For chat history settings, set record_conversation to true to let candle-vllm remember chat history. By default, candle-vllm does not record chat history; instead, the client sends both the new messages and the contextual history to candle-vllm. If record_conversation is set to true, the client sends only new chat messages, and candle-vllm is responsible for recording the previous ones. However, this approach requires per-session chat recording, which is not yet implemented, so the default record_conversation=false is recommended.
For chat streaming, the stream flag in the chat request needs to be set to true.
candle-vllm --p 2000 --w /home/mistral_7b/
The --max-gen-tokens parameter controls the maximum number of output tokens per chat response. By default, it is set to 1/5 of max_sequence_len.
For consumer GPUs, it is suggested to run models in GGML formats (or Marlin format), e.g.,
candle-vllm --p 2000 --w /home/Meta-Llama-3.1-8B-Instruct/ --isq q4k
where isq is one of ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k", "awq", "gptq", "marlin", "gguf", "ggml"].
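With the default record_conversation=false, the client owns the conversation: every request resends the full message list, and the assistant's reply must be appended before the next turn. With the stream flag set, the reply arrives incrementally. The sketch below assumes the standard OpenAI streaming format (server-sent `data: {...}` lines whose `choices[].delta.content` carries text, terminated by `data: [DONE]`); `ChatHistory` and `iter_stream_text` are hypothetical names, not part of candle-vllm.

```python
import json
from typing import Iterable, Iterator, Optional


class ChatHistory:
    """Client-side message history (hypothetical helper) for the default
    record_conversation=false mode, where every request carries all turns."""

    def __init__(self, system_prompt: Optional[str] = None):
        self.messages: list = []
        if system_prompt:
            self.messages.append({"role": "system", "content": system_prompt})

    def request(self, user_msg: str, model: str, stream: bool = True) -> dict:
        """Append the user turn and return a chat request body with the full history."""
        self.messages.append({"role": "user", "content": user_msg})
        return {"model": model, "messages": list(self.messages), "stream": stream}

    def record_reply(self, text: str) -> None:
        """Store the assistant's reply so the next request includes it."""
        self.messages.append({"role": "assistant", "content": text})


def iter_stream_text(lines: Iterable[str]) -> Iterator[str]:
    """Yield text deltas from OpenAI-style server-sent event lines."""
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue  # skip blank separators and non-data fields
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        for choice in chunk.get("choices", []):
            content = (choice.get("delta") or {}).get("content")
            if content:
                yield content
```

A client would POST the body from `history.request(...)`, feed the response lines to `iter_stream_text`, join the pieces, and pass the result to `history.record_reply` before sending the next turn.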
Use Marlin kernel to speedup GPTQ/AWQ models
Candle-vllm supports the GPTQ/AWQ Marlin kernel, so you can run these models directly, for example:
candle-vllm --dtype f16 --w /home/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4-Marlin/
Or convert an existing 4-bit AWQ model to Marlin-compatible format:
python3 examples/convert_awq_marlin.py --src /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4/ --dst /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/ --bits 4 --method awq --group 128 --nk False
candle-vllm --dtype f16 --d 0 --w /home/Meta-Llama-3.1-8B-Instruct-AWQ-INT4-Marlin/
You may also use GPTQModel to transform a model to Marlin-compatible format using the given script examples/convert_marlin.py.
Note: for the Marlin fast kernel, only 4-bit GPTQ quantization is supported at the moment.
Installing candle-vllm takes only the steps described above. If you have any problems, please create an issue.
The following features are planned to be implemented, but contributions are especially welcome:
- Sampling methods:
  - Beam search (huggingface/candle#1319)
- More pipelines (from candle-transformers)
- Python implementation: vllm-project
- vllm paper