Commit 1787198

[Bugfix] Fix sequence parallelism bug when pipeline parallelism is enabled (vllm-project#24021)
Signed-off-by: cascade812 <[email protected]>
1 parent 759ef49 commit 1787198

6 files changed (+135, -42 lines)

tests/distributed/test_sequence_parallel.py

Lines changed: 2 additions & 1 deletion

@@ -235,7 +235,6 @@ def _compare_sp(
         'level': 3,
         'custom_ops': ["+rms_norm"],
         'compile_sizes': [4, 8],
-        'splitting_ops': [],
         'pass_config': {
             'enable_sequence_parallelism': True,
             'enable_fusion': enable_fusion,

@@ -251,6 +250,8 @@ def _compare_sp(
         *common_args,
         "--tensor-parallel-size",
         str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
         "--distributed-executor-backend",
         distributed_backend,
         "--compilation_config",

vllm/distributed/parallel_state.py

Lines changed: 46 additions & 5 deletions

@@ -663,14 +663,29 @@ def send_tensor_dict(
         tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else

@@ -699,14 +714,23 @@ def send_tensor_dict(
         # `send_object_list` has serialization & deserialization,
         # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
-        for tensor in tensor_list:
+
+        tensor_keys = [
+            k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
+        ]
+        assert len(tensor_keys) == len(tensor_list)
+
+        for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue

             # send-allgather: send only a slice, then do allgather.
-            if (all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0):
+            use_all_gather = (all_gather_group is not None
+                              and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather
+            if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

             if tensor.is_cpu:

@@ -725,14 +749,29 @@ def recv_tensor_dict(
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else

@@ -766,6 +805,8 @@ def recv_tensor_dict(
             # send-allgather: send only a slice, then do allgather.
             use_all_gather = (all_gather_group is not None
                               and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather

             if use_all_gather:
                 orig_shape = tensor.shape
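
The docstrings above describe the intent of the new keyword argument; the sketch below shows how a pipeline-parallel boundary would pass it, assuming pp_group and tp_group are already-initialized GroupCoordinator instances. This is an illustrative usage sketch, not code from the commit, and it only runs inside a distributed setup.

import torch

# Hypothetical shapes: hidden_states is replicated across the TP group,
# while the residual has already been scattered by sequence parallelism,
# so each TP rank holds only num_tokens // tp_world_size rows of it.
hidden_states = torch.empty(8, 4096)
residual = torch.empty(8 // tp_group.world_size, 4096)

# Disable the slice-and-all-gather path only for the scattered tensor;
# the default size-divisibility heuristic still applies to the rest.
all_gather_tensors = {"residual": False}

pp_group.send_tensor_dict(
    {"hidden_states": hidden_states, "residual": residual},
    all_gather_group=tp_group,
    all_gather_tensors=all_gather_tensors,
)

# The receiving pipeline stage must pass the same overrides so both sides
# agree on which tensors were sent as slices.
received = pp_group.recv_tensor_dict(
    all_gather_group=tp_group,
    all_gather_tensors=all_gather_tensors,
)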

vllm/v1/worker/cpu_worker.py

Lines changed: 15 additions & 3 deletions

@@ -19,6 +19,7 @@
 from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
+from vllm.v1.worker.utils import is_residual_scattered_for_sp

 logger = init_logger(__name__)

@@ -107,18 +108,29 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)

         if not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+            get_pp_group().send_tensor_dict(
+                output.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             return None

         assert isinstance(output, ModelRunnerOutput)

vllm/v1/worker/gpu_model_runner.py

Lines changed: 33 additions & 30 deletions

@@ -88,6 +88,7 @@
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
     KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+from vllm.v1.worker.utils import is_residual_scattered_for_sp

 from .utils import (AttentionGroup, MultiModalBudget,
                     add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,

@@ -1633,30 +1634,23 @@ def sync_and_slice_intermediate_tensors(
         assert self.intermediate_tensors is not None

         tp = self.vllm_config.parallel_config.tensor_parallel_size
-        enabled_sp = self.compilation_config.pass_config. \
-            enable_sequence_parallelism
-        if enabled_sp:
-            # When sequence parallelism is enabled, we always pad num_tokens
-            # to be a multiple of tensor_parallel_size (tp) earlier
-            assert num_tokens % tp == 0
-        is_residual_scattered = tp > 1 and enabled_sp \
-            and num_tokens % tp == 0
+        is_rs = is_residual_scattered_for_sp(self.vllm_config, num_tokens)

         # When sequence parallelism is enabled, the "residual" tensor is sharded
         # across tensor parallel ranks, so each rank only needs its own slice.
         if sync_self:
             assert intermediate_tensors is not None
             for k, v in intermediate_tensors.items():
-                is_scattered = k == "residual" and is_residual_scattered
+                is_scattered = k == "residual" and is_rs
                 copy_len = num_tokens // tp if is_scattered else \
                     num_tokens
                 self.intermediate_tensors[k][:copy_len].copy_(
                     v[:copy_len], non_blocking=True)

         return IntermediateTensors({
             k:
-            v[:num_tokens // tp]
-            if k == "residual" and is_residual_scattered else v[:num_tokens]
+            v[:num_tokens //
+              tp] if k == "residual" and is_rs else v[:num_tokens]
             for k, v in self.intermediate_tensors.items()
         })

@@ -1741,6 +1735,25 @@ def _pool(
             pooler_output=pooler_output,
         )

+    def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
+        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
+                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
+                and hasattr(self, "cudagraph_batch_sizes")
+                and self.cudagraph_batch_sizes
+                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
+            # Use CUDA graphs.
+            # Add padding to the batch size.
+            return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
+
+        # Eager mode.
+        # Pad tokens to multiple of tensor_parallel_size when
+        # enabled collective fusion for SP
+        tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+        if (self.compilation_config.pass_config.enable_sequence_parallelism
+                and tp_size > 1):
+            return round_up(num_scheduled_tokens, tp_size)
+        return num_scheduled_tokens
+
     def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",

@@ -1750,24 +1763,7 @@ def _preprocess(
                Optional[IntermediateTensors], dict[str, Any]]:

         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
-                and not envs.VLLM_DISABLE_PAD_FOR_CUDAGRAPH
-                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
-            # Use CUDA graphs.
-            # Add padding to the batch size.
-            num_input_tokens = self.vllm_config.pad_for_cudagraph(
-                num_scheduled_tokens)
-        else:
-            # Eager mode.
-            # Pad tokens to multiple of tensor_parallel_size when
-            # enabled collective fusion for SP
-            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
-            if self.compilation_config.pass_config. \
-                enable_sequence_parallelism and tp_size > 1:
-                num_input_tokens = round_up(num_scheduled_tokens, tp_size)
-            else:
-                num_input_tokens = num_scheduled_tokens
-
+        num_input_tokens = self._get_num_input_tokens(num_scheduled_tokens)
         # Padding for DP
         num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
         num_input_tokens += num_pad

@@ -2108,8 +2104,15 @@ def execute_model(
         assert not self.is_pooling_model

         if not get_pp_group().is_last_rank:
+            all_gather_tensors = {
+                "residual":
+                not is_residual_scattered_for_sp(
+                    self.vllm_config, num_input_tokens)
+            }
             get_pp_group().send_tensor_dict(
-                hidden_states.tensors, all_gather_group=get_tp_group())
+                hidden_states.tensors,
+                all_gather_group=get_tp_group(),
+                all_gather_tensors=all_gather_tensors)
             logits = None
         else:
             sample_hidden_states = hidden_states[logits_indices]
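
The new _get_num_input_tokens helper centralizes the padding rule that was previously inlined in _preprocess, so the workers can compute the same padded size before deciding whether the residual is scattered. Below is a small self-contained sketch of the eager-mode branch with hypothetical numbers; the round_up here mirrors the rounding the method relies on.

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

# Hypothetical case: 10 scheduled tokens, tensor-parallel size 4, the
# sequence-parallelism pass enabled, and CUDA graphs disabled, so the
# eager-mode branch applies.
num_scheduled_tokens = 10
tp_size = 4
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
assert num_input_tokens == 12  # each TP rank then handles 12 // 4 = 3 tokens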

vllm/v1/worker/gpu_worker.py

Lines changed: 13 additions & 2 deletions

@@ -32,6 +32,7 @@
                             DraftTokenIds, ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.utils import is_residual_scattered_for_sp
 from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

@@ -428,10 +429,19 @@ def execute_model(
     ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]:
         intermediate_tensors = None
         forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        num_input_tokens = self.model_runner._get_num_input_tokens(
+            num_scheduled_tokens)
+        all_gather_tensors = {
+            "residual":
+            not is_residual_scattered_for_sp(self.vllm_config,
+                                             num_input_tokens)
+        }
         if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
+                    all_gather_group=get_tp_group(),
+                    all_gather_tensors=all_gather_tensors))

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)

@@ -444,7 +454,8 @@ def execute_model(
                 "external_launcher") and not get_pp_group().is_last_rank

             get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
+                                            all_gather_group=get_tp_group(),
+                                            all_gather_tensors=all_gather_tensors)

         kv_connector_output = output.kv_connector_output
         if not kv_connector_output:

vllm/v1/worker/utils.py

Lines changed: 26 additions & 1 deletion

@@ -7,7 +7,7 @@
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import ModelConfig, SchedulerConfig
+from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.cache import processor_only_cache_from_config

@@ -288,3 +288,28 @@ def bind_kv_cache(
     for layer_name, kv_cache in kv_caches.items():
         # NOTE: Use list because of v0 PP virtual engine.
         forward_context[layer_name].kv_cache = [kv_cache]
+
+
+def is_residual_scattered_for_sp(vllm_config: VllmConfig,
+                                 num_input_tokens: int) -> bool:
+    """Check if the residual tensor is scattered for sequence parallelism.
+
+    The residual tensor is scattered across tensor parallel ranks when sequence
+    parallelism and tensor parallelism is enabled, and the number of
+    input tokens is one of the compilation sizes.
+    """
+    if not vllm_config.compilation_config.pass_config.\
+        enable_sequence_parallelism:
+        return False
+
+    tp = vllm_config.parallel_config.tensor_parallel_size
+
+    if tp == 1:
+        return False
+
+    # When sequence parallelism is enabled, we always pad num_input_tokens
+    # to be a multiple of tensor_parallel_size (tp) earlier.
+    assert num_input_tokens % tp == 0
+
+    # Currently, SP is only enabled for static size fx graphs.
+    return (num_input_tokens in vllm_config.compilation_config.compile_sizes)
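
is_residual_scattered_for_sp is what lets the workers build the per-tensor override: the size heuristic in send_tensor_dict alone cannot detect an already-scattered residual, because a tensor holding only num_input_tokens // tp rows still has a numel divisible by the TP world size and would be sliced a second time. A shape-only sketch of that failure mode, with hypothetical sizes:

import torch

tp_size = 4
num_input_tokens, hidden_size = 16, 8

# Fully replicated hidden_states: the default send-allgather optimization
# is safe, so each rank can send 1/tp_size of it.
hidden_states = torch.randn(num_input_tokens, hidden_size)
assert hidden_states.numel() % tp_size == 0  # heuristic: slice and all-gather

# Residual already scattered by sequence parallelism: each rank holds only
# its own rows, yet the divisibility check still passes, so without the
# override the tensor would be sliced again and reconstructed incorrectly.
residual_slice = torch.randn(num_input_tokens // tp_size, hidden_size)
assert residual_slice.numel() % tp_size == 0

# Hence the explicit per-tensor override the workers construct:
all_gather_tensors = {"residual": False}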
