|  | 
| 15 | 15 | # See the License for the specific language governing permissions and | 
| 16 | 16 | # limitations under the License. | 
| 17 | 17 | 
 | 
|  | 18 | +import contextlib | 
| 18 | 19 | from dataclasses import dataclass | 
| 19 |  | -from typing import Dict, List, Literal, Optional, Tuple, Union | 
|  | 20 | +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union | 
| 20 | 21 | 
 | 
| 21 | 22 | import torch | 
| 22 | 23 | 
 | 
| 23 | 24 | from ..utils import get_logger | 
| 24 | 25 | 
 | 
| 25 | 26 | 
 | 
|  | 27 | +if TYPE_CHECKING: | 
|  | 28 | +    from ..pipelines.pipeline_utils import DiffusionPipeline | 
|  | 29 | +    from .modeling_utils import ModelMixin | 
|  | 30 | + | 
|  | 31 | + | 
| 26 | 32 | logger = get_logger(__name__)  # pylint: disable=invalid-name | 
| 27 | 33 | 
 | 
| 28 | 34 | 
 | 
| @@ -117,3 +123,53 @@ def __repr__(self): | 
| 117 | 123 | # A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of | 
| 118 | 124 | # the module should be split/gathered across context parallel region. | 
| 119 | 125 | ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]] | 
|  | 126 | + | 
|  | 127 | + | 
|  | 128 | +_ENABLE_PARALLELISM_WARN_ONCE = False | 
|  | 129 | + | 
|  | 130 | + | 
|  | 131 | +@contextlib.contextmanager | 
|  | 132 | +def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]): | 
|  | 133 | +    from diffusers import DiffusionPipeline, ModelMixin | 
|  | 134 | + | 
|  | 135 | +    from .attention_dispatch import _AttentionBackendRegistry | 
|  | 136 | + | 
|  | 137 | +    global _ENABLE_PARALLELISM_WARN_ONCE | 
|  | 138 | +    if not _ENABLE_PARALLELISM_WARN_ONCE: | 
|  | 139 | +        logger.warning( | 
|  | 140 | +            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future." | 
|  | 141 | +        ) | 
|  | 142 | +        _ENABLE_PARALLELISM_WARN_ONCE = True | 
|  | 143 | + | 
|  | 144 | +    if isinstance(model_or_pipeline, DiffusionPipeline): | 
|  | 145 | +        parallelized_components = [ | 
|  | 146 | +            (name, component) | 
|  | 147 | +            for name, component in model_or_pipeline.components.items() | 
|  | 148 | +            if getattr(component, "_internal_parallel_config", None) is not None | 
|  | 149 | +        ] | 
|  | 150 | +        if len(parallelized_components) > 1: | 
|  | 151 | +            raise ValueError( | 
|  | 152 | +                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run " | 
|  | 153 | +                "different stages of the pipeline separately with `enable_parallelism` on each component manually." | 
|  | 154 | +            ) | 
|  | 155 | +        if len(parallelized_components) == 0: | 
|  | 156 | +            raise ValueError( | 
|  | 157 | +                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized." | 
|  | 158 | +            ) | 
|  | 159 | +        _, model_or_pipeline = parallelized_components[0] | 
|  | 160 | +    elif isinstance(model_or_pipeline, ModelMixin): | 
|  | 161 | +        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None: | 
|  | 162 | +            raise ValueError( | 
|  | 163 | +                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager." | 
|  | 164 | +            ) | 
|  | 165 | +    else: | 
|  | 166 | +        raise TypeError( | 
|  | 167 | +            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline." | 
|  | 168 | +        ) | 
|  | 169 | + | 
|  | 170 | +    old_parallel_config = _AttentionBackendRegistry._parallel_config | 
|  | 171 | +    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config | 
|  | 172 | + | 
|  | 173 | +    yield | 
|  | 174 | + | 
|  | 175 | +    _AttentionBackendRegistry._parallel_config = old_parallel_config | 
0 commit comments