Skip to content

Commit 256d5a9

Browse files
committed
refactor
1 parent c777184 commit 256d5a9

File tree

4 files changed

+66
-11
lines changed

4 files changed

+66
-11
lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@
215215
"MultiAdapter",
216216
"MultiControlNetModel",
217217
"OmniGenTransformer2DModel",
218+
"ParallelConfig",
218219
"PixArtTransformer2DModel",
219220
"PriorTransformer",
220221
"QwenImageTransformer2DModel",
@@ -243,6 +244,7 @@
243244
"WanTransformer3DModel",
244245
"WanVACETransformer3DModel",
245246
"attention_backend",
247+
"enable_parallelism",
246248
]
247249
)
248250
_import_structure["modular_pipelines"].extend(
@@ -879,6 +881,7 @@
879881
MultiAdapter,
880882
MultiControlNetModel,
881883
OmniGenTransformer2DModel,
884+
ParallelConfig,
882885
PixArtTransformer2DModel,
883886
PriorTransformer,
884887
QwenImageTransformer2DModel,
@@ -906,6 +909,7 @@
906909
WanTransformer3DModel,
907910
WanVACETransformer3DModel,
908911
attention_backend,
912+
enable_parallelism,
909913
)
910914
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
911915
from .optimization import (

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
_import_structure = {}
2626

2727
if is_torch_available():
28-
_import_structure["_modeling_parallel"] = ["ParallelConfig"]
28+
_import_structure["_modeling_parallel"] = ["ParallelConfig", "enable_parallelism"]
2929
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
3030
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
3131
_import_structure["auto_model"] = ["AutoModel"]
@@ -115,7 +115,7 @@
115115

116116
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
117117
if is_torch_available():
118-
from ._modeling_parallel import ParallelConfig
118+
from ._modeling_parallel import ParallelConfig, enable_parallelism
119119
from .adapter import MultiAdapter, T2IAdapter
120120
from .attention_dispatch import AttentionBackendName, attention_backend
121121
from .auto_model import AutoModel

src/diffusers/models/_modeling_parallel.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,20 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717

18+
import contextlib
1819
from dataclasses import dataclass
19-
from typing import Dict, List, Literal, Optional, Tuple, Union
20+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
2021

2122
import torch
2223

2324
from ..utils import get_logger
2425

2526

27+
if TYPE_CHECKING:
28+
from ..pipelines.pipeline_utils import DiffusionPipeline
29+
from .modeling_utils import ModelMixin
30+
31+
2632
logger = get_logger(__name__) # pylint: disable=invalid-name
2733

2834

@@ -117,3 +123,53 @@ def __repr__(self):
117123
# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
118124
# the module should be split/gathered across context parallel region.
119125
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
126+
127+
128+
_ENABLE_PARALLELISM_WARN_ONCE = False


@contextlib.contextmanager
def enable_parallelism(model_or_pipeline: Union["DiffusionPipeline", "ModelMixin"]):
    """Activate the parallel config of a parallelized model for the duration of a `with` block.

    Installs the model's `_internal_parallel_config` (set by `.parallelize()`) on
    `_AttentionBackendRegistry` and restores the previous config on exit.

    Args:
        model_or_pipeline: Either a `ModelMixin` that has been parallelized, or a
            `DiffusionPipeline` containing exactly one parallelized component
            (that component's config is used).

    Raises:
        ValueError: If a pipeline has zero or more than one parallelized component,
            or if a model has not been parallelized.
        TypeError: If the argument is neither a `DiffusionPipeline` nor a `ModelMixin`.
    """
    # Imported lazily to avoid circular imports at module load time.
    from diffusers import DiffusionPipeline, ModelMixin

    from .attention_dispatch import _AttentionBackendRegistry

    global _ENABLE_PARALLELISM_WARN_ONCE
    if not _ENABLE_PARALLELISM_WARN_ONCE:
        logger.warning(
            "Support for `enable_parallelism` is experimental and the API may be subject to change in the future."
        )
        _ENABLE_PARALLELISM_WARN_ONCE = True

    if isinstance(model_or_pipeline, DiffusionPipeline):
        # Collect components that were parallelized via `.parallelize()`.
        parallelized_components = [
            (name, component)
            for name, component in model_or_pipeline.components.items()
            if getattr(component, "_internal_parallel_config", None) is not None
        ]
        if len(parallelized_components) > 1:
            raise ValueError(
                "Enabling parallelism on a pipeline is not possible when multiple internal components are parallelized. Please run "
                "different stages of the pipeline separately with `enable_parallelism` on each component manually."
            )
        if len(parallelized_components) == 0:
            raise ValueError(
                "No parallelized components found in the pipeline. Please ensure at least one component is parallelized."
            )
        # Exactly one parallelized component: use it in place of the pipeline.
        _, model_or_pipeline = parallelized_components[0]
    elif isinstance(model_or_pipeline, ModelMixin):
        if getattr(model_or_pipeline, "_internal_parallel_config", None) is None:
            raise ValueError(
                "The model is not parallelized. Please ensure the model is parallelized with `.parallelize()` before using this context manager."
            )
    else:
        raise TypeError(
            f"Expected a `DiffusionPipeline` or `ModelMixin` instance, but got {type(model_or_pipeline)}. Please provide a valid model or pipeline."
        )

    old_parallel_config = _AttentionBackendRegistry._parallel_config
    _AttentionBackendRegistry._parallel_config = model_or_pipeline._internal_parallel_config
    try:
        yield
    finally:
        # BUGFIX: restore even when the `with` body raises; otherwise an exception
        # inside the context would permanently leak the parallel config into the
        # global attention backend registry.
        _AttentionBackendRegistry._parallel_config = old_parallel_config

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
249249
_skip_layerwise_casting_patterns = None
250250
_supports_group_offloading = True
251251
_repeated_blocks = []
252+
_internal_parallel_config = None
252253
_cp_plan = None
253254

254255
def __init__(self):
@@ -1480,10 +1481,8 @@ def compile_repeated_blocks(self, *args, **kwargs):
14801481
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
14811482
)
14821483

1483-
@contextmanager
14841484
def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
1485-
from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
1486-
from .attention_dispatch import _AttentionBackendRegistry
1485+
from ..hooks.context_parallel import apply_context_parallel
14871486

14881487
logger.warning(
14891488
"`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
@@ -1530,11 +1529,7 @@ def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, Con
15301529
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
15311530

15321531
apply_context_parallel(self, parallel_config, cp_plan)
1533-
_AttentionBackendRegistry._parallel_config = parallel_config
1534-
1535-
yield
1536-
1537-
remove_context_parallel(self, cp_plan)
1532+
self._internal_parallel_config = parallel_config
15381533

15391534
@classmethod
15401535
def _load_pretrained_model(

0 commit comments

Comments
 (0)