|
25 | 25 | from ...loaders import FromOriginalModelMixin, PeftAdapterMixin |
26 | 26 | from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers |
27 | 27 | from ...utils.torch_utils import maybe_allow_in_graph |
| 28 | +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput |
28 | 29 | from ..attention import AttentionMixin, FeedForward |
29 | 30 | from ..attention_dispatch import dispatch_attention_fn |
30 | 31 | from ..attention_processor import Attention |
@@ -502,6 +503,18 @@ class QwenImageTransformer2DModel( |
502 | 503 | _no_split_modules = ["QwenImageTransformerBlock"] |
503 | 504 | _skip_layerwise_casting_patterns = ["pos_embed", "norm"] |
504 | 505 | _repeated_blocks = ["QwenImageTransformerBlock"] |
| 506 | + _cp_plan = { |
| 507 | + "": { |
| 508 | + "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 509 | + "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), |
| 510 | + "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), |
| 511 | + }, |
| 512 | + "pos_embed": { |
| 513 | + 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), |
| 514 | + 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True), |
| 515 | + }, |
| 516 | + "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), |
| 517 | + } |
505 | 518 |
|
506 | 519 | @register_to_config |
507 | 520 | def __init__( |
|
0 commit comments