@@ -23,6 +23,7 @@
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
 from ..cache_utils import CacheMixin
@@ -539,6 +540,19 @@ class WanTransformer3DModel(
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
     _repeated_blocks = ["WanTransformerBlock"]
+    _cp_plan = {
+        "rope": {
+            0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+            1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+        },
+        "blocks.0": {
+            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+        },
+        "blocks.*": {
+            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+        },
+        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+    }
 
     @register_to_config
     def __init__(
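One reading of this plan: hidden_states is sharded along the sequence dimension once, at "blocks.0", and stays sharded through the block stack; encoder_hidden_states is sharded at every block ("blocks.*"); for "rope", split_output=True means the module's two outputs are sharded rather than its inputs; and "proj_out" gathers the sequence shards back. Below is a minimal, single-process sketch (not part of this diff) illustrating what these entries describe. The world_size, tensor shapes, and variable names are illustrative assumptions; the actual sharding is performed by diffusers' context-parallel hooks at runtime, not by code like this.

import torch

world_size = 2  # hypothetical context-parallel degree (assumption)

# "blocks.0" -> hidden_states: ContextParallelInput(split_dim=1, expected_dims=3)
# hidden_states is (batch, seq_len, dim); each rank keeps a slice of seq_len.
hidden_states = torch.randn(1, 8, 16)
per_rank_hidden = torch.chunk(hidden_states, world_size, dim=1)
assert per_rank_hidden[0].shape == (1, 4, 16)

# "rope" -> outputs 0 and 1: split_dim=1, expected_dims=4, split_output=True
# The module runs on its full input and only its (4-D) outputs are sharded
# along the sequence dimension; the shape here is purely illustrative.
rope_freqs = torch.randn(1, 8, 1, 64)
per_rank_freqs = torch.chunk(rope_freqs, world_size, dim=1)

# "proj_out" -> ContextParallelOutput(gather_dim=1, expected_dims=3):
# after the last block, the per-rank sequence shards are gathered back.
gathered = torch.cat(per_rank_hidden, dim=1)
assert torch.equal(gathered, hidden_states)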