diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index d2b3d8a733f3..c10bf3ed4f7b 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -676,8 +676,8 @@ class Flux2Transformer2DModel( "": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), - "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False), + "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), + "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), }