@@ -605,12 +605,13 @@ def forward(
605605        controlnet_cond : List [torch .Tensor ],
606606        control_type : torch .Tensor ,
607607        control_type_idx : List [int ],
608-         conditioning_scale : float  =  1.0 ,
608+         conditioning_scale : Union [ float ,  List [ float ]]  =  1.0 ,
609609        class_labels : Optional [torch .Tensor ] =  None ,
610610        timestep_cond : Optional [torch .Tensor ] =  None ,
611611        attention_mask : Optional [torch .Tensor ] =  None ,
612612        added_cond_kwargs : Optional [Dict [str , torch .Tensor ]] =  None ,
613613        cross_attention_kwargs : Optional [Dict [str , Any ]] =  None ,
614+         from_multi : bool  =  False ,
614615        guess_mode : bool  =  False ,
615616        return_dict : bool  =  True ,
616617    ) ->  Union [ControlNetOutput , Tuple [Tuple [torch .Tensor , ...], torch .Tensor ]]:
@@ -647,6 +648,8 @@ def forward(
647648                Additional conditions for the Stable Diffusion XL UNet. 
648649            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): 
649650                A kwargs dictionary that if specified is passed along to the `AttnProcessor`. 
651+             from_multi (`bool`, defaults to `False`): 
652+                 Use standard scaling when called from `MultiControlNetUnionModel`. 
650653            guess_mode (`bool`, defaults to `False`): 
651654                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if 
652655                you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. 
@@ -658,6 +661,9 @@ def forward(
658661                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is 
659662                returned where the first element is the sample tensor. 
660663        """ 
664+         if  isinstance (conditioning_scale , float ):
665+             conditioning_scale  =  [conditioning_scale ] *  len (controlnet_cond )
666+ 
661667        # check channel order 
662668        channel_order  =  self .config .controlnet_conditioning_channel_order 
663669
@@ -742,12 +748,16 @@ def forward(
742748        inputs  =  []
743749        condition_list  =  []
744750
745-         for  cond , control_idx   in  zip (controlnet_cond , control_type_idx ):
751+         for  cond , control_idx ,  scale   in  zip (controlnet_cond , control_type_idx ,  conditioning_scale ):
746752            condition  =  self .controlnet_cond_embedding (cond )
747753            feat_seq  =  torch .mean (condition , dim = (2 , 3 ))
748754            feat_seq  =  feat_seq  +  self .task_embedding [control_idx ]
749-             inputs .append (feat_seq .unsqueeze (1 ))
750-             condition_list .append (condition )
755+             if  from_multi :
756+                 inputs .append (feat_seq .unsqueeze (1 ))
757+                 condition_list .append (condition )
758+             else :
759+                 inputs .append (feat_seq .unsqueeze (1 ) *  scale )
760+                 condition_list .append (condition  *  scale )
751761
752762        condition  =  sample 
753763        feat_seq  =  torch .mean (condition , dim = (2 , 3 ))
@@ -759,10 +769,13 @@ def forward(
759769            x  =  layer (x )
760770
761771        controlnet_cond_fuser  =  sample  *  0.0 
762-         for  idx , condition   in  enumerate (condition_list [:- 1 ]):
772+         for  ( idx , condition ),  scale   in  zip ( enumerate (condition_list [:- 1 ]),  conditioning_scale ):
763773            alpha  =  self .spatial_ch_projs (x [:, idx ])
764774            alpha  =  alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
765-             controlnet_cond_fuser  +=  condition  +  alpha 
775+             if  from_multi :
776+                 controlnet_cond_fuser  +=  condition  +  alpha 
777+             else :
778+                 controlnet_cond_fuser  +=  condition  +  alpha  *  scale 
766779
767780        sample  =  sample  +  controlnet_cond_fuser 
768781
@@ -806,12 +819,13 @@ def forward(
806819        # 6. scaling 
807820        if  guess_mode  and  not  self .config .global_pool_conditions :
808821            scales  =  torch .logspace (- 1 , 0 , len (down_block_res_samples ) +  1 , device = sample .device )  # 0.1 to 1.0 
809-             scales  =  scales  *  conditioning_scale 
822+             if  from_multi :
823+                 scales  =  scales  *  conditioning_scale [0 ]
810824            down_block_res_samples  =  [sample  *  scale  for  sample , scale  in  zip (down_block_res_samples , scales )]
811825            mid_block_res_sample  =  mid_block_res_sample  *  scales [- 1 ]  # last one 
812-         else :
813-             down_block_res_samples  =  [sample  *  conditioning_scale  for  sample  in  down_block_res_samples ]
814-             mid_block_res_sample  =  mid_block_res_sample  *  conditioning_scale 
826+         elif   from_multi :
827+             down_block_res_samples  =  [sample  *  conditioning_scale [ 0 ]  for  sample  in  down_block_res_samples ]
828+             mid_block_res_sample  =  mid_block_res_sample  *  conditioning_scale [ 0 ] 
815829
816830        if  self .config .global_pool_conditions :
817831            down_block_res_samples  =  [
0 commit comments