Commit ad39106
[CPU] Enable data parallel for CPU backend (vllm-project#23903)
Signed-off-by: jiang1.li <[email protected]>
1 parent 2554b27 commit ad39106

File tree: 6 files changed, +48 -9 lines

.buildkite/scripts/hardware_ci/run-cpu-test.sh

Lines changed: 20 additions & 4 deletions
@@ -25,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
 
 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=16 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
 
 function cpu_tests() {
   set -e
@@ -89,17 +89,33 @@ function cpu_tests() {
   pytest -x -s -v \
     tests/lora/test_qwen2vl.py"
 
-  # online serving
+  # online serving: tp+pp
   docker exec cpu-test-"$NUMA_NODE" bash -c '
     set -e
     VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
+    server_pid=$!
     timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
     vllm bench serve \
       --backend vllm \
      --dataset-name random \
      --model meta-llama/Llama-3.2-3B-Instruct \
      --num-prompts 20 \
-      --endpoint /v1/completions'
+      --endpoint /v1/completions
+    kill -s SIGTERM $server_pid &'
+
+  # online serving: tp+dp
+  docker exec cpu-test-"$NUMA_NODE" bash -c '
+    set -e
+    VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 &
+    server_pid=$!
+    timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
+    vllm bench serve \
+      --backend vllm \
+      --dataset-name random \
+      --model meta-llama/Llama-3.2-3B-Instruct \
+      --num-prompts 20 \
+      --endpoint /v1/completions
+    kill -s SIGTERM $server_pid &'
 }
 
 # All of CPU tests are expected to be finished less than 40 mins.
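For a quick sanity check of either server from Python, a minimal client against the OpenAI-compatible endpoint works as well. A hedged sketch: the model name and port follow the CI commands above, while the prompt and the use of `requests` are illustrative choices, not part of the test:

# Minimal sketch: query the running server's /v1/completions endpoint.
# Model and port match the tests above; the prompt is illustrative.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-3.2-3B-Instruct",
        "prompt": "Data parallelism on CPU lets vLLM",
        "max_tokens": 32,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])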

docs/getting_started/installation/cpu.md

Lines changed: 2 additions & 1 deletion
@@ -96,6 +96,7 @@ Currently, there are no pre-built CPU wheels.
 - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`.
 - `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively.
 - `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`.
+- `CPU_VISIBLE_MEMORY_NODES`: specify the visible NUMA memory nodes for vLLM CPU workers, similar to `CUDA_VISIBLE_DEVICES`. The variable only takes effect when `VLLM_CPU_OMP_THREADS_BIND` is set to `auto`. It gives finer control over the auto thread-binding feature, such as masking nodes or changing the node binding sequence.
 - `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False).
 - `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False).

@@ -179,7 +180,7 @@ Inference batch size is an important parameter for the performance. Larger batch
 - Offline Inference: `256 * world_size`
 - Online Serving: `128 * world_size`
 
-vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP together if there are enough CPU sockets and memory nodes.
+vLLM CPU supports data parallel (DP), tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details on tuning DP, TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use DP, TP and PP together if there are enough CPU sockets and memory nodes.
 
 ### Which quantization configs does vLLM CPU support?
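To illustrate the masking semantics documented above, here is a small self-contained sketch; the helper name and node lists are hypothetical, not part of vLLM's API:

# Hypothetical helper mirroring the documented CPU_VISIBLE_MEMORY_NODES
# behavior: keep only the listed NUMA nodes, in the listed order.
import os

def visible_numa_nodes(present_nodes: list[int]) -> list[int]:
    raw = os.environ.get("CPU_VISIBLE_MEMORY_NODES", "")
    if not raw:
        return present_nodes  # unset or empty: all detected nodes stay visible
    requested = [int(s) for s in raw.split(",")]
    # Preserving the requested order is what allows reordering the binding sequence.
    return [n for n in requested if n in present_nodes]

os.environ["CPU_VISIBLE_MEMORY_NODES"] = "1,0"
print(visible_numa_nodes([0, 1]))  # [1, 0]: node 1 is bound first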

docs/getting_started/installation/cpu/x86.inc.md

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ docker build -f docker/Dockerfile.cpu \
 
 # Launching OpenAI server
 docker run --rm \
-    --privileged=true \
+    --security-opt seccomp=unconfined \
     --shm-size=4g \
     -p 8000:8000 \
     -e VLLM_CPU_KVCACHE_SPACE=<KV cache space> \
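Note that `--security-opt seccomp=unconfined` is a much narrower relaxation than `--privileged=true`: it only lifts Docker's default syscall filter, likely to permit the NUMA memory-policy syscalls (such as `mbind` and `set_mempolicy`) that thread binding relies on, without granting the container full device and capability access.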

vllm/platforms/cpu.py

Lines changed: 8 additions & 0 deletions
@@ -69,6 +69,7 @@ class CpuPlatform(Platform):
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
     dist_backend: str = "gloo"
+    device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
 
     @property
     def supported_dtypes(self) -> list[torch.dtype]:
@@ -297,6 +298,13 @@ def get_allowed_cpu_core_node_list(
                 allowed_numa_nodes.add(x.numa_node)  # type: ignore
         allowed_numa_nodes_list = sorted(allowed_numa_nodes)
 
+        env_key = CpuPlatform.device_control_env_var
+        if (env_key in os.environ and os.environ[env_key] != ""):
+            visible_nodes = [int(s) for s in os.environ[env_key].split(',')]
+            allowed_numa_nodes_list = [
+                x for x in visible_nodes if x in allowed_cpu_id_list
+            ]
+
        return allowed_numa_nodes_list, logical_cpu_list
 
     @classmethod
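Registering `device_control_env_var` lets the platform layer treat NUMA memory nodes the way the CUDA platform treats GPUs: per the docs change above, setting for example `CPU_VISIBLE_MEMORY_NODES=1,0` restricts auto thread binding to nodes 1 and 0 and binds them in that order.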

vllm/v1/worker/cpu_model_runner.py

Lines changed: 6 additions & 1 deletion
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 import torch.nn as nn
@@ -113,6 +113,11 @@ def _sync_device(self) -> None:
     def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
         return sampled_token_ids.tolist()
 
+    def get_dp_padding(self,
+                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+        # Note: For CPU backend, dp padding is not required for now.
+        return 0, None
+
 
 @contextmanager
 def _torch_cuda_wrapper():
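For context, `get_dp_padding` exists because some backends must pad each DP rank's batch to a common token count so collectives see uniform shapes; the CPU override opts out by reporting zero padding. A hedged sketch of that general contract, with illustrative names (this is not vLLM's actual GPU implementation):

# Illustrative sketch of the DP padding contract: each rank pads its token
# count up to the max across all DP ranks.
def get_dp_padding_sketch(num_tokens: int,
                          num_tokens_across_dp: list[int]) -> int:
    max_tokens = max(num_tokens_across_dp)
    return max_tokens - num_tokens  # extra pad tokens this rank must add

print(get_dp_padding_sketch(5, [5, 8, 3]))  # the rank with 5 tokens pads by 3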

vllm/v1/worker/cpu_worker.py

Lines changed: 11 additions & 2 deletions
@@ -55,7 +55,14 @@ def init_device(self):
             else:
                 self.local_omp_cpuid = "all"
         else:
-            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
+            local_dp_rank = self.parallel_config.data_parallel_rank_local
+            omp_cpuids = omp_cpuids.split("|")
+            if local_dp_rank is not None:
+                world_size = self.parallel_config.world_size
+                omp_cpuids = omp_cpuids[local_dp_rank *
+                                        world_size:(local_dp_rank + 1) *
+                                        world_size]
+            self.local_omp_cpuid = omp_cpuids[self.rank]
 
         if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
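The slicing above partitions the pipe-separated core ranges of `VLLM_CPU_OMP_THREADS_BIND` so that each DP replica owns a contiguous block of `world_size` entries. A standalone sketch of the index arithmetic, using an illustrative four-range binding for a `-tp=2 -dp=2` launch:

# Illustrative: one core range per (dp_rank, tp_rank) pair.
omp_cpuids = "0-15|16-31|32-47|48-63".split("|")
world_size = 2  # TP * PP ranks inside one DP replica

def local_omp_cpuid(local_dp_rank: int, rank: int) -> str:
    # Each DP replica takes a contiguous block of world_size entries.
    block = omp_cpuids[local_dp_rank * world_size:(local_dp_rank + 1) * world_size]
    return block[rank]

print(local_omp_cpuid(0, 0))  # 0-15  (DP rank 0, TP rank 0)
print(local_omp_cpuid(1, 1))  # 48-63 (DP rank 1, TP rank 1)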
@@ -162,7 +169,9 @@ def _get_autobind_cpu_ids(
         # Reserve CPUs for other processes
         reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU
         if reserve_cpu_num is None:
-            reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0
+            need_reserve = (self.parallel_config.world_size > 1 or
+                            self.parallel_config.data_parallel_size_local > 1)
+            reserve_cpu_num = 1 if need_reserve else 0
         assert len(logical_cpu_list) > reserve_cpu_num, (
             f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) "
             f"should be less than {len(logical_cpu_list)}.")
