src/accelerate/accelerator.py (5 changes: 0 additions & 5 deletions)
@@ -1637,11 +1637,6 @@ def _get_tensor_address(p):
         return args
 
     def _prepare_cp(self, *args):
-        # Skip CP setup if SP (Sequence Parallelism) is actually enabled (sp_size > 1)
-        # CP and SP are mutually exclusive
-        if self.parallelism_config.sp_enabled:
-            return args
-
         from torch.distributed.tensor.experimental import context_parallel
         from torch.distributed.tensor.experimental._attention import set_rotate_method
 
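To make the behavioral change concrete, here is a minimal, self-contained sketch using toy classes (not Accelerate's real `Accelerator` or `ParallelismConfig`): the guard removed above silently returned the inputs untouched whenever SP was enabled, while the validation added below rejects the conflicting sizes as soon as the config is constructed.

```python
# Toy sketch only; class names and attributes are illustrative, not Accelerate's API.


class SilentSkip:
    """Old pattern: quietly skip CP setup when SP is also requested."""

    def __init__(self, cp_size, sp_size):
        self.cp_size, self.sp_size = cp_size, sp_size

    def prepare_cp(self, *args):
        if self.sp_size > 1:
            # The caller asked for both CP and SP but gets no warning or error.
            return args
        # ... CP setup would happen here ...
        return args


class FailFast:
    """New pattern: reject the conflicting combination up front."""

    def __init__(self, cp_size, sp_size):
        if cp_size > 1 and sp_size > 1:
            raise ValueError("CP and SP are mutually exclusive; set cp_size=1 or sp_size=1.")
        self.cp_size, self.sp_size = cp_size, sp_size
```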
src/accelerate/parallelism_config.py (8 changes: 8 additions & 0 deletions)
@@ -320,6 +320,14 @@ def __post_init__(self):
         if self.sp_backend not in valid_sp_backends:
             raise ValueError(f"sp_backend must be one of {valid_sp_backends}, but got {self.sp_backend}")
 
+        # CP and SP are mutually exclusive
+        if self.cp_size > 1 and self.sp_size > 1:
+            raise ValueError(
+                "Context Parallelism (CP) and Sequence Parallelism (SP) are mutually exclusive. "
+                f"Got cp_size={self.cp_size} and sp_size={self.sp_size}. "
+                "Please set either cp_size=1 or sp_size=1."
+            )
+
         if (self.tp_size > 1 or self.cp_size > 1) and self.dp_replicate_size > 1 and self.dp_shard_size == 1:
             raise ValueError(
                 "Tensor/Context parallelism (tp/cp_size > 1) cannot be used with pure data parallelism (dp_replicate_size > 1 and dp_shard_size == 1). "
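For reference, a hedged usage sketch of how the new check surfaces to users. It assumes `ParallelismConfig` can be imported from `accelerate.parallelism_config` and accepts `cp_size` and `sp_size` as keyword fields defaulting to 1, which the diff suggests but does not show in full.

```python
# Illustrative usage; assumes cp_size/sp_size are keyword fields defaulting to 1.
from accelerate.parallelism_config import ParallelismConfig

try:
    # Previously this combination was accepted and CP setup was silently skipped
    # later in Accelerator._prepare_cp; it now fails at construction time.
    ParallelismConfig(cp_size=2, sp_size=2)
except ValueError as err:
    print(err)  # Context Parallelism (CP) and Sequence Parallelism (SP) are mutually exclusive. ...
```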