Skip to content

Commit ae5a707

Browse files
committed
support qwen
1 parent a820bfd commit ae5a707

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
2626
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
2727
from ...utils.torch_utils import maybe_allow_in_graph
28+
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
2829
from ..attention import AttentionMixin, FeedForward
2930
from ..attention_dispatch import dispatch_attention_fn
3031
from ..attention_processor import Attention
@@ -502,6 +503,18 @@ class QwenImageTransformer2DModel(
502503
_no_split_modules = ["QwenImageTransformerBlock"]
503504
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
504505
_repeated_blocks = ["QwenImageTransformerBlock"]
506+
_cp_plan = {
507+
"": {
508+
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
509+
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
510+
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
511+
},
512+
"pos_embed": {
513+
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
514+
1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
515+
},
516+
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
517+
}
505518

506519
@register_to_config
507520
def __init__(

0 commit comments

Comments
 (0)