Skip to content

Commit f4c1b4e

Browse files
committed
support wan
1 parent da78c5d commit f4c1b4e

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2424
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
2525
from ...utils.torch_utils import maybe_allow_in_graph
26+
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
2627
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
2728
from ..attention_dispatch import dispatch_attention_fn
2829
from ..cache_utils import CacheMixin
@@ -539,6 +540,19 @@ class WanTransformer3DModel(
539540
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
540541
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
541542
_repeated_blocks = ["WanTransformerBlock"]
543+
_cp_plan = {
544+
"rope": {
545+
0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
546+
1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
547+
},
548+
"blocks.0": {
549+
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
550+
},
551+
"blocks.*": {
552+
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
553+
},
554+
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
555+
}
542556

543557
@register_to_config
544558
def __init__(

0 commit comments

Comments
 (0)