@@ -67,7 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
6767 return key_img , value_img
6868
6969
70- # Copied from diffusers.models.transformers.transformer_wan.WanAttnProcessor
70+ # Modified from diffusers.models.transformers.transformer_wan.WanAttnProcessor
7171class WanAttnProcessor :
7272 _attention_backend = None
7373 _parallel_config = None
@@ -137,7 +137,8 @@ def apply_rotary_emb(
137137 dropout_p = 0.0 ,
138138 is_causal = False ,
139139 backend = self ._attention_backend ,
140- parallel_config = self ._parallel_config ,
140+ # Reference: https://github.com/huggingface/diffusers/pull/12660
141+ parallel_config = None ,
141142 )
142143 hidden_states_img = hidden_states_img .flatten (2 , 3 )
143144 hidden_states_img = hidden_states_img .type_as (query )
@@ -150,7 +151,8 @@ def apply_rotary_emb(
150151 dropout_p = 0.0 ,
151152 is_causal = False ,
152153 backend = self ._attention_backend ,
153- parallel_config = self ._parallel_config ,
154+ # Reference: https://github.com/huggingface/diffusers/pull/12660
155+ parallel_config = (self ._parallel_config if encoder_hidden_states is None else None ),
154156 )
155157 hidden_states = hidden_states .flatten (2 , 3 )
156158 hidden_states = hidden_states .type_as (query )
@@ -568,9 +570,11 @@ class ChronoEditTransformer3DModel(
568570 "blocks.0" : {
569571 "hidden_states" : ContextParallelInput (split_dim = 1 , expected_dims = 3 , split_output = False ),
570572 },
571- "blocks.*" : {
572- "encoder_hidden_states" : ContextParallelInput (split_dim = 1 , expected_dims = 3 , split_output = False ),
573- },
573+ # Reference: https://github.com/huggingface/diffusers/pull/12660
574+ # We need to disable the splitting of encoder_hidden_states because
575+ # the image_encoder consistently generates 257 tokens for image_embed. This causes
576+ # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
577+ # after concatenation—to be indivisible by the number of devices in the context-parallel (CP) group.
574578 "proj_out" : ContextParallelOutput (gather_dim = 1 , expected_dims = 3 ),
575579 }
576580
0 commit comments