Fix Context Parallel validation checks #12446
@@ -79,29 +79,47 @@ def __post_init__(self):

        if self.ulysses_degree is None:
            self.ulysses_degree = 1

        if self.ring_degree == 1 and self.ulysses_degree == 1:
            raise ValueError(
                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
            )
        if self.ring_degree < 1 or self.ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
        if self.ring_degree > 1 and self.ulysses_degree > 1:
            raise ValueError(
                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
            )
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )

    @property
    def mesh_shape(self) -> Tuple[int, int]:
        """Shape of the device mesh (ring_degree, ulysses_degree)."""
        return (self.ring_degree, self.ulysses_degree)

    @property
    def mesh_dim_names(self) -> Tuple[str, str]:
        """Dimension names for the device mesh."""
        return ("ring", "ulysses")

    def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

        if self.ulysses_degree * self.ring_degree > world_size:
Review comments on this line:
- Should we ever hit this line, as both cannot be set, right?
- Both can technically be set, but currently both can't be > 1. Also, this is for cases where you have 3 GPUs available and set something like `ulysses_degree=1` and `ring_degree=4` (more GPUs are being requested than the world size).
- Feels slightly confusing to me, but since we're erroring out early for unsupported …
            raise ValueError(
                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
            )
        if self._flattened_mesh is None:
            self._flattened_mesh = self._mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self._mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self._mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = self._ring_mesh.get_local_rank()
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()

        self._flattened_mesh = self._mesh._flatten()
        self._ring_mesh = self._mesh["ring"]
        self._ulysses_mesh = self._mesh["ulysses"]
        self._ring_local_rank = self._ring_mesh.get_local_rank()
        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
Review comments on lines +122 to +126:
- Can't they be `None`?
- They are internal attributes that are derived from `mesh`, which is set through the `setup()` method. The guards are redundant; they would always be `None` here.
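To make the reworked `__post_init__` checks (and the degree discussion in the threads above) concrete, here is a minimal, self-contained sketch that mirrors the validation logic from this hunk. The `_Degrees` class is a toy stand-in invented for illustration, not the actual `ContextParallelConfig`:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Degrees:
    """Toy stand-in mirroring the validation added to ContextParallelConfig.__post_init__."""

    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None

    def __post_init__(self):
        # Unset degrees default to 1, as in the PR.
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.ring_degree == 1 and self.ulysses_degree == 1:
            raise ValueError("Either ring_degree or ulysses_degree must be greater than 1.")
        if self.ring_degree < 1 or self.ulysses_degree < 1:
            raise ValueError("Degrees must be >= 1.")
        if self.ring_degree > 1 and self.ulysses_degree > 1:
            raise ValueError("Unified Ulysses-Ring attention is not supported yet.")


_Degrees(ring_degree=4)     # ok: (ring=4, ulysses=1) -> ring attention only
_Degrees(ulysses_degree=4)  # ok: (ring=1, ulysses=4) -> Ulysses attention only

for bad in (
    {"ring_degree": 2, "ulysses_degree": 2},  # both > 1: unified mode not supported
    {"ring_degree": 1, "ulysses_degree": 1},  # neither > 1: nothing to parallelize
    {"ring_degree": 0, "ulysses_degree": 4},  # degree < 1: invalid
):
    try:
        _Degrees(**bad)
    except ValueError as err:
        print(f"{bad}: {err}")
```

In short, `(ring_degree, ulysses_degree)` pairs such as `(4, 1)` or `(1, 4)` are accepted, while `(1, 1)` and `(2, 2)` are rejected; `setup()` additionally requires `ring_degree * ulysses_degree <= world_size`.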
@dataclass

@@ -119,22 +137,22 @@ class ParallelConfig:

    _rank: int = None
    _world_size: int = None
    _device: torch.device = None
    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _mesh: torch.distributed.device_mesh.DeviceMesh = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        *,
        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._cp_mesh = cp_mesh
        self._mesh = mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
            self.context_parallel_config.setup(rank, world_size, device, mesh)

@dataclass(frozen=True)
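For reference, a rough sketch of how the new `mesh_shape` and `mesh_dim_names` properties are meant to be consumed: a 2-D device mesh with named dimensions `("ring", "ulysses")` is created once, and the per-dimension sub-meshes are then looked up by name. This is an illustration only (not part of the PR); it assumes a PyTorch build with `torch.distributed.device_mesh` and must be launched with `torchrun`, one process per rank:

```python
# Launch with e.g.: torchrun --nproc-per-node 4 mesh_demo.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")  # use "nccl" on GPUs
world_size = dist.get_world_size()

# Mirrors mesh_shape = (ring_degree, ulysses_degree); here ring-only parallelism.
ring_degree, ulysses_degree = world_size, 1
mesh = init_device_mesh(
    device_type="cpu",  # "cuda" on GPUs
    mesh_shape=(ring_degree, ulysses_degree),
    mesh_dim_names=("ring", "ulysses"),
)

ring_mesh = mesh["ring"]        # sub-mesh used for ring attention
ulysses_mesh = mesh["ulysses"]  # sub-mesh used for Ulysses attention
print(dist.get_rank(), ring_mesh.get_local_rank(), ulysses_mesh.get_local_rank())

dist.destroy_process_group()
```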
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:

    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
    _supports_context_parallel = {}
    _supports_context_parallel = set()
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS

@@ -237,7 +237,9 @@ def decorator(func):

            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
            cls._supports_context_parallel[backend] = supports_context_parallel
            if supports_context_parallel:
                cls._supports_context_parallel.add(backend.value)

            return func

        return decorator
@@ -251,15 +253,12 @@ def list_backends(cls):

        return list(cls._backends.keys())

    @classmethod
    def _is_context_parallel_enabled(
        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
    def _is_context_parallel_available(
        cls,
        backend: AttentionBackendName,
    ) -> bool:
        supports_context_parallel = backend in cls._supports_context_parallel
        is_degree_greater_than_1 = parallel_config is not None and (
            parallel_config.context_parallel_config.ring_degree > 1
            or parallel_config.context_parallel_config.ulysses_degree > 1
        )
        return supports_context_parallel and is_degree_greater_than_1
        supports_context_parallel = backend.value in cls._supports_context_parallel
        return supports_context_parallel
Review comment on lines -257 to +261:
- Very nice cleanup here!
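The dict-to-set change can be mimicked in isolation. A small sketch with invented names (not the real diffusers registry) showing why a plain set of backend names is sufficient once the degree check is dropped from the registry:

```python
import inspect


class _ToyRegistry:
    """Mimics the reworked registry bookkeeping: membership in a set marks CP support."""

    _backends = {}
    _supported_arg_names = {}
    _supports_context_parallel = set()  # backend names, not a name -> bool mapping

    @classmethod
    def register(cls, name, supports_context_parallel=False):
        def decorator(func):
            cls._backends[name] = func
            cls._supported_arg_names[name] = set(inspect.signature(func).parameters.keys())
            if supports_context_parallel:
                cls._supports_context_parallel.add(name)
            return func

        return decorator

    @classmethod
    def _is_context_parallel_available(cls, name):
        # Availability is now a pure property of the backend, independent of any parallel config.
        return name in cls._supports_context_parallel


@_ToyRegistry.register("flash", supports_context_parallel=True)
def flash_attention(query, key, value):
    ...


@_ToyRegistry.register("native", supports_context_parallel=False)
def native_attention(query, key, value):
    ...


print(_ToyRegistry._is_context_parallel_available("flash"))   # True
print(_ToyRegistry._is_context_parallel_available("native"))  # False
```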
@contextlib.contextmanager

@@ -306,14 +305,6 @@ def dispatch_attention_fn(

    backend_name = AttentionBackendName(backend)
    backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
        backend_name, parallel_config
    ):
        raise ValueError(
            f"Backend {backend_name} either does not support context parallelism or context parallelism "
            f"was enabled with a world size of 1."
        )

    kwargs = {
        "query": query,
        "key": key,
@@ -1484,59 +1484,72 @@ def enable_parallelism(

        config: Union[ParallelConfig, ContextParallelConfig],
        cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
    ):
        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        logger.warning(
            "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
        )

        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
            raise RuntimeError(
                "torch.distributed must be available and initialized before calling `enable_parallelism`."
            )

        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
        from .attention_processor import Attention, MochiAttention

        if isinstance(config, ContextParallelConfig):
            config = ParallelConfig(context_parallel_config=config)

        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device_type = torch._C._get_accelerator().type
        device_module = torch.get_device_module(device_type)
        device = torch.device(device_type, rank % device_module.device_count())

        cp_mesh = None
        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

        # Step 1: Validate attention backend supports context parallelism if enabled
        if config.context_parallel_config is not None:
            cp_config = config.context_parallel_config
            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
                raise ValueError(
                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
                )
            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
                raise ValueError(
                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
                )
            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
                device_type=device_type,
                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
                mesh_dim_names=("ring", "ulysses"),
            )
            for module in self.modules():
                if not isinstance(module, attention_classes):
                    continue

        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
                processor = module.processor
                if processor is None or not hasattr(processor, "_attention_backend"):
                    continue

        if cp_plan is None and self._cp_plan is None:
            raise ValueError(
                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
            )
        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
                attention_backend = processor._attention_backend
                if attention_backend is None:
                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
                else:
                    attention_backend = AttentionBackendName(attention_backend)

                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
                    raise ValueError(
                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
                        f"calling `enable_parallelism()`."
                    )

                # All modules use the same attention processor and backend. We don't need to
                # iterate over all modules after checking the first processor
                break

        mesh = None
        if config.context_parallel_config is not None:
            apply_context_parallel(self, config.context_parallel_config, cp_plan)
            cp_config = config.context_parallel_config
            mesh = torch.distributed.device_mesh.init_device_mesh(
                device_type=device_type,
                mesh_shape=cp_config.mesh_shape,
                mesh_dim_names=cp_config.mesh_dim_names,
            )

        config.setup(rank, world_size, device, mesh=mesh)
        self._parallel_config = config

        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
        for module in self.modules():
            if not isinstance(module, attention_classes):
                continue
@@ -1545,6 +1558,14 @@ def enable_parallelism(

                continue
            processor._parallel_config = config

        if config.context_parallel_config is not None:
            if cp_plan is None and self._cp_plan is None:
                raise ValueError(
                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
                )
            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
            apply_context_parallel(self, config.context_parallel_config, cp_plan)

    @classmethod
    def _load_pretrained_model(
        cls,
Review comment:
- Would it be possible to add a small explainer about what it would mean for different values, for example "(3, 1)", "(1, 3)", etc.? When both are being set, both cannot be > 1.
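For context, this is roughly how the reworked validation is exercised from user code: the attention backend is chosen before `enable_parallelism`, which now fails fast if that backend cannot do context parallelism. A hedged sketch, not an official example: the top-level `ContextParallelConfig` import, the `AutoModel` entry point, the `"flash"` backend choice, and the placeholder checkpoint id are assumptions that may need adapting to your diffusers version and model:

```python
# Launch with e.g.: torchrun --nproc-per-node 2 run_cp.py
import torch
import torch.distributed as dist

from diffusers import AutoModel, ContextParallelConfig  # import paths assumed; adapt if needed

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# "<repo-id>" is a placeholder checkpoint; the model is assumed to define a `_cp_plan`,
# otherwise pass cp_plan=... explicitly to enable_parallelism().
model = AutoModel.from_pretrained("<repo-id>", subfolder="transformer", torch_dtype=torch.bfloat16)
model.to("cuda")

# Pick a backend registered with supports_context_parallel=True *before* enabling parallelism;
# enable_parallelism() now raises early if the active backend cannot shard attention.
model.set_attention_backend("flash")

# ring_degree=2, ulysses_degree defaults to 1 -> a (2, 1) ring/ulysses device mesh.
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
```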