Commit 7c16f3f

[Doc] Add documents for multi-node distributed serving with MP backend (vllm-project#30509)
Signed-off-by: Isotr0py <[email protected]>
1 parent ddbfbe5 commit 7c16f3f

File tree

2 files changed (+24 −4)

docs/serving/parallelism_scaling.md

Lines changed: 23 additions & 1 deletion
@@ -62,7 +62,7 @@ If a single node lacks sufficient GPUs to hold the model, deploy vLLM across mul
 
 ### What is Ray?
 
-Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments require Ray as the runtime engine.
+Ray is a distributed computing framework for scaling Python programs. Multi-node vLLM deployments can use Ray as the runtime engine.
 
 vLLM uses Ray to manage the distributed execution of tasks across multiple nodes and control where execution happens.
 
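For context, the Ray runtime itself is brought up before `vllm serve` is launched on the cluster. A minimal sketch using Ray's CLI, assuming Ray's default port 6379:

```bash
# On the head node: start the Ray head process.
ray start --head --port=6379

# On every other node: join the cluster via the head node's address.
ray start --address=<HEAD_NODE_IP>:6379

# Verify that all nodes and GPUs are visible before starting vLLM.
ray status
```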

@@ -130,6 +130,28 @@ vllm serve /path/to/the/model/in/the/container \
     --distributed-executor-backend ray
 ```
 
+### Running vLLM with multiprocessing
+
+Besides Ray, multi-node vLLM deployments can also use `multiprocessing` as the runtime engine. Here's an example of deploying a model across 2 nodes (8 GPUs per node) with `tp_size=8` and `pp_size=2`.
+
+Choose one node as the head node and run:
+
+```bash
+vllm serve /path/to/the/model/in/the/container \
+    --tensor-parallel-size 8 --pipeline-parallel-size 2 \
+    --nnodes 2 --node-rank 0 \
+    --master-addr <HEAD_NODE_IP>
+```
+
+On the other worker node, run:
+
+```bash
+vllm serve /path/to/the/model/in/the/container \
+    --tensor-parallel-size 8 --pipeline-parallel-size 2 \
+    --nnodes 2 --node-rank 1 \
+    --master-addr <HEAD_NODE_IP> --headless
+```
+
 ## Optimizing network communication for tensor parallelism
 
 Efficient tensor parallelism requires fast inter-node communication, preferably through high-speed network adapters such as InfiniBand.
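Once both commands above are running, the deployment can be smoke-tested from any machine that can reach the head node. A minimal sketch against the OpenAI-compatible API, assuming vLLM's default serving port 8000 and that the model is addressed by the path it was served from:

```bash
# Query the OpenAI-compatible completions endpoint on the head node.
curl http://<HEAD_NODE_IP>:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "/path/to/the/model/in/the/container",
        "prompt": "Hello, my name is",
        "max_tokens": 32
    }'
```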

vllm/v1/executor/multiproc_executor.py

Lines changed: 1 addition & 3 deletions
@@ -124,9 +124,7 @@ def _init_executor(self) -> None:
         # Set multiprocessing envs
         set_multiprocessing_worker_envs()
 
-        # Multiprocessing-based executor does not support multi-node setting.
-        # Since it only works for single node, we can use the loopback address
-        # get_loopback_ip() for communication.
+        # use the loopback address get_loopback_ip() for communication.
         distributed_init_method = get_distributed_init_method(
             get_loopback_ip(), get_open_port()
         )
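The executor touched above backs the multiprocessing path described in the new docs. For completeness, the backend can also be pinned explicitly on the CLI rather than inferred; a sketch of the head-node launch, assuming the `mp` value accepted by `--distributed-executor-backend`:

```bash
# Same head-node launch as in the docs, with the multiprocessing
# backend spelled out explicitly instead of being inferred.
vllm serve /path/to/the/model/in/the/container \
    --tensor-parallel-size 8 --pipeline-parallel-size 2 \
    --nnodes 2 --node-rank 0 \
    --master-addr <HEAD_NODE_IP> \
    --distributed-executor-backend mp
```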
