@@ -611,6 +611,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        from_multi: bool = False,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -647,6 +648,8 @@ def forward(
                 Additional conditions for the Stable Diffusion XL UNet.
             cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            from_multi (`bool`, defaults to `False`):
+                Use standard scaling when called from `MultiControlNetUnionModel`.
             guess_mode (`bool`, defaults to `False`):
                 In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
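For orientation, a minimal usage sketch of the two calling conventions the new flag distinguishes. It is illustrative only: the tensors and the `controlnet` instance are assumed to exist, and keyword names not visible in this excerpt (`sample`, `timestep`, `encoder_hidden_states`, `controlnet_cond`, `control_type`, `control_type_idx`) are taken from the model's signature rather than from these hunks.

# Direct call (from_multi=False, the default): each entry of conditioning_scale
# is applied to its own condition inside forward(), as before this change.
down_res, mid_res = controlnet(
    sample=latents,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    controlnet_cond=[pose_image, depth_image],   # one tensor per active condition
    control_type=control_type,
    control_type_idx=[0, 1],
    conditioning_scale=[0.8, 0.5],               # per-condition scales
    added_cond_kwargs=added_cond_kwargs,
    from_multi=False,
    return_dict=False,
)

# Call from MultiControlNetUnionModel (from_multi=True): per-condition scaling
# inside forward() is skipped, and conditioning_scale[0] is applied uniformly
# to the returned down/mid residuals instead (see the scaling hunks below).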
@@ -749,8 +752,12 @@ def forward(
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            inputs.append(feat_seq.unsqueeze(1) * scale)
-            condition_list.append(condition * scale)
+            if from_multi:
+                inputs.append(feat_seq.unsqueeze(1))
+                condition_list.append(condition)
+            else:
+                inputs.append(feat_seq.unsqueeze(1) * scale)
+                condition_list.append(condition * scale)

         condition = sample
         feat_seq = torch.mean(condition, dim=(2, 3))
@@ -765,7 +772,10 @@ def forward(
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            controlnet_cond_fuser += condition + alpha * scale
+            if from_multi:
+                controlnet_cond_fuser += condition + alpha
+            else:
+                controlnet_cond_fuser += condition + alpha * scale

         sample = sample + controlnet_cond_fuser

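The fusion branch above is easier to follow with shapes written out. Below is a self-contained sketch with toy dimensions, a stand-in `nn.Linear` for `self.spatial_ch_projs`, and a `condition_list` that omits the trailing `sample`-derived entry (hence no `[:-1]`); it is a restatement of the loop, not the model's code.

import torch
import torch.nn as nn

batch, channels, height, width, embed_dim = 2, 320, 64, 64, 256

spatial_ch_projs = nn.Linear(embed_dim, channels)             # stand-in for self.spatial_ch_projs
x = torch.randn(batch, 3, embed_dim)                          # transformer output: one token per condition (+1 for sample)
condition_list = [torch.randn(batch, channels, height, width) for _ in range(2)]
conditioning_scale = [0.8, 0.5]
from_multi = False

controlnet_cond_fuser = torch.zeros(batch, channels, height, width)
for (idx, condition), scale in zip(enumerate(condition_list), conditioning_scale):
    alpha = spatial_ch_projs(x[:, idx])                       # (B, C): one additive bias per channel
    alpha = alpha.unsqueeze(-1).unsqueeze(-1)                 # (B, C, 1, 1), broadcast over H x W
    if from_multi:
        controlnet_cond_fuser += condition + alpha            # scaling deferred to the end of forward()
    else:
        controlnet_cond_fuser += condition + alpha * scale    # per-condition scaling, as before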
@@ -809,8 +819,13 @@ def forward(
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
+            if from_multi:
+                scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
+        elif from_multi:
+            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

         if self.config.global_pool_conditions:
             down_block_res_samples = [
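For completeness, a sketch of the calling side this flag exists for. This is not the actual `MultiControlNetUnionModel.forward`; the loop structure, argument names, and merging here are assumptions meant only to show why each child can skip per-condition scaling and let `conditioning_scale[0]` be applied uniformly to its residuals.

def multi_controlnet_union_forward(
    nets, sample, timestep, encoder_hidden_states,
    controlnet_cond, control_type, control_type_idx, conditioning_scale, **kwargs
):
    """Hypothetical wrapper loop. Each child ControlNetUnionModel is called with
    from_multi=True, so it skips per-condition scaling internally and multiplies
    its residuals by conditioning_scale[0]; the wrapper only has to sum them."""
    down_block_res_samples, mid_block_res_sample = None, None
    for image, ctype, ctype_idx, scale, net in zip(
        controlnet_cond, control_type, control_type_idx, conditioning_scale, nets
    ):
        down_samples, mid_sample = net(
            sample=sample,
            timestep=timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=image,
            control_type=ctype,
            control_type_idx=ctype_idx,
            conditioning_scale=scale,      # a per-condition list; only scale[0] matters when from_multi=True
            from_multi=True,
            return_dict=False,
            **kwargs,
        )
        if down_block_res_samples is None:
            down_block_res_samples, mid_block_res_sample = list(down_samples), mid_sample
        else:
            down_block_res_samples = [
                prev + curr for prev, curr in zip(down_block_res_samples, down_samples)
            ]
            mid_block_res_sample = mid_block_res_sample + mid_sample
    return down_block_res_samples, mid_block_res_sample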