Skip to content

Commit 8c628eb

Browse files
authored
Update transformer_qwenimage.py
1 parent e62804f commit 8c628eb

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ def forward(
552552
txt_seq_lens: Optional[List[int]] = None,
553553
guidance: torch.Tensor = None, # TODO: this should probably be removed
554554
attention_kwargs: Optional[Dict[str, Any]] = None,
555+
controlnet_block_samples = None,
555556
return_dict: bool = True,
556557
) -> Union[torch.Tensor, Transformer2DModelOutput]:
557558
"""
@@ -631,6 +632,12 @@ def forward(
631632
joint_attention_kwargs=attention_kwargs,
632633
)
633634

635+
# controlnet residual
636+
if controlnet_block_samples is not None:
637+
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
638+
interval_control = int(np.ceil(interval_control))
639+
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
640+
634641
# Use only the image part (hidden_states) from the dual-stream blocks
635642
hidden_states = self.norm_out(hidden_states, temb)
636643
output = self.proj_out(hidden_states)

0 commit comments

Comments
 (0)