@@ -605,12 +605,13 @@ def forward(
605605 controlnet_cond : List [torch .Tensor ],
606606 control_type : torch .Tensor ,
607607 control_type_idx : List [int ],
608- conditioning_scale : float = 1.0 ,
608+ conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
609609 class_labels : Optional [torch .Tensor ] = None ,
610610 timestep_cond : Optional [torch .Tensor ] = None ,
611611 attention_mask : Optional [torch .Tensor ] = None ,
612612 added_cond_kwargs : Optional [Dict [str , torch .Tensor ]] = None ,
613613 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
614+ from_multi : bool = False ,
614615 guess_mode : bool = False ,
615616 return_dict : bool = True ,
616617 ) -> Union [ControlNetOutput , Tuple [Tuple [torch .Tensor , ...], torch .Tensor ]]:
@@ -647,6 +648,8 @@ def forward(
647648 Additional conditions for the Stable Diffusion XL UNet.
648649 cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
649650 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
651+ from_multi (`bool`, defaults to `False`):
652+ Use standard scaling when called from `MultiControlNetUnionModel`.
650653 guess_mode (`bool`, defaults to `False`):
651654 In this mode, the ControlNet encoder tries its best to recognize the content of the input even if
652655 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -658,6 +661,9 @@ def forward(
658661 If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
659662 returned where the first element is the sample tensor.
660663 """
664+ if isinstance (conditioning_scale , float ):
665+ conditioning_scale = [conditioning_scale ] * len (controlnet_cond )
666+
661667 # check channel order
662668 channel_order = self .config .controlnet_conditioning_channel_order
663669
@@ -742,12 +748,16 @@ def forward(
742748 inputs = []
743749 condition_list = []
744750
745- for cond , control_idx in zip (controlnet_cond , control_type_idx ):
751+ for cond , control_idx , scale in zip (controlnet_cond , control_type_idx , conditioning_scale ):
746752 condition = self .controlnet_cond_embedding (cond )
747753 feat_seq = torch .mean (condition , dim = (2 , 3 ))
748754 feat_seq = feat_seq + self .task_embedding [control_idx ]
749- inputs .append (feat_seq .unsqueeze (1 ))
750- condition_list .append (condition )
755+ if from_multi :
756+ inputs .append (feat_seq .unsqueeze (1 ))
757+ condition_list .append (condition )
758+ else :
759+ inputs .append (feat_seq .unsqueeze (1 ) * scale )
760+ condition_list .append (condition * scale )
751761
752762 condition = sample
753763 feat_seq = torch .mean (condition , dim = (2 , 3 ))
@@ -759,10 +769,13 @@ def forward(
759769 x = layer (x )
760770
761771 controlnet_cond_fuser = sample * 0.0
762- for idx , condition in enumerate (condition_list [:- 1 ]):
772+ for ( idx , condition ), scale in zip ( enumerate (condition_list [:- 1 ]), conditioning_scale ):
763773 alpha = self .spatial_ch_projs (x [:, idx ])
764774 alpha = alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
765- controlnet_cond_fuser += condition + alpha
775+ if from_multi :
776+ controlnet_cond_fuser += condition + alpha
777+ else :
778+ controlnet_cond_fuser += condition + alpha * scale
766779
767780 sample = sample + controlnet_cond_fuser
768781
@@ -806,12 +819,13 @@ def forward(
806819 # 6. scaling
807820 if guess_mode and not self .config .global_pool_conditions :
808821 scales = torch .logspace (- 1 , 0 , len (down_block_res_samples ) + 1 , device = sample .device ) # 0.1 to 1.0
809- scales = scales * conditioning_scale
822+ if from_multi :
823+ scales = scales * conditioning_scale [0 ]
810824 down_block_res_samples = [sample * scale for sample , scale in zip (down_block_res_samples , scales )]
811825 mid_block_res_sample = mid_block_res_sample * scales [- 1 ] # last one
812- else :
813- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples ]
814- mid_block_res_sample = mid_block_res_sample * conditioning_scale
826+ elif from_multi :
827+ down_block_res_samples = [sample * conditioning_scale [ 0 ] for sample in down_block_res_samples ]
828+ mid_block_res_sample = mid_block_res_sample * conditioning_scale [ 0 ]
815829
816830 if self .config .global_pool_conditions :
817831 down_block_res_samples = [
0 commit comments