Skip to content

Commit a54832d

Browse files
committed
from_multi
1 parent a1f1c70 commit a54832d

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ def forward(
611611
attention_mask: Optional[torch.Tensor] = None,
612612
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
613613
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
614+
from_multi: bool = False,
614615
guess_mode: bool = False,
615616
return_dict: bool = True,
616617
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -647,6 +648,8 @@ def forward(
647648
Additional conditions for the Stable Diffusion XL UNet.
648649
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
649650
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
651+
from_multi (`bool`, defaults to `False`):
652+
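Set to `True` when called from `MultiControlNetUnionModel`. The per-condition scale is then not applied to the condition embeddings or the fused condition; instead the output residuals are multiplied by `conditioning_scale[0]`.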
650653
guess_mode (`bool`, defaults to `False`):
651654
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
652655
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -749,8 +752,12 @@ def forward(
749752
condition = self.controlnet_cond_embedding(cond)
750753
feat_seq = torch.mean(condition, dim=(2, 3))
751754
feat_seq = feat_seq + self.task_embedding[control_idx]
752-
inputs.append(feat_seq.unsqueeze(1) * scale)
753-
condition_list.append(condition * scale)
755+
if from_multi:
756+
inputs.append(feat_seq.unsqueeze(1))
757+
condition_list.append(condition)
758+
else:
759+
inputs.append(feat_seq.unsqueeze(1) * scale)
760+
condition_list.append(condition * scale)
754761

755762
condition = sample
756763
feat_seq = torch.mean(condition, dim=(2, 3))
@@ -765,7 +772,10 @@ def forward(
765772
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
766773
alpha = self.spatial_ch_projs(x[:, idx])
767774
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
768-
controlnet_cond_fuser += condition + alpha * scale
775+
if from_multi:
776+
controlnet_cond_fuser += condition + alpha
777+
else:
778+
controlnet_cond_fuser += condition + alpha * scale
769779

770780
sample = sample + controlnet_cond_fuser
771781

@@ -809,8 +819,13 @@ def forward(
809819
# 6. scaling
810820
if guess_mode and not self.config.global_pool_conditions:
811821
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
822+
if from_multi:
823+
scales = scales * conditioning_scale[0]
812824
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
813825
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
826+
elif from_multi:
827+
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
828+
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
814829

815830
if self.config.global_pool_conditions:
816831
down_block_res_samples = [

src/diffusers/models/controlnets/multicontrolnet_union.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,12 @@ def forward(
4747
guess_mode: bool = False,
4848
return_dict: bool = True,
4949
) -> Union[ControlNetOutput, Tuple]:
50+
down_block_res_samples, mid_block_res_sample = None, None
5051
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
5152
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
5253
):
54+
if scale == 0.0:
55+
continue
5356
down_samples, mid_sample = controlnet(
5457
sample=sample,
5558
timestep=timestep,
@@ -68,7 +71,7 @@ def forward(
6871
)
6972

7073
# merge samples
71-
if i == 0:
74+
if down_block_res_samples is None and mid_block_res_sample is None:
7275
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
7376
else:
7477
down_block_res_samples = [

0 commit comments

Comments
 (0)