@@ -605,7 +605,7 @@ def forward(
605605 controlnet_cond : List [torch .Tensor ],
606606 control_type : torch .Tensor ,
607607 control_type_idx : List [int ],
608- conditioning_scale : float = 1.0 ,
608+ conditioning_scale : Union [ float , List [ float ]] = 1.0 ,
609609 class_labels : Optional [torch .Tensor ] = None ,
610610 timestep_cond : Optional [torch .Tensor ] = None ,
611611 attention_mask : Optional [torch .Tensor ] = None ,
@@ -658,6 +658,9 @@ def forward(
658658 If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
659659 returned where the first element is the sample tensor.
660660 """
661+ if isinstance (conditioning_scale , float ):
662+ conditioning_scale = [conditioning_scale ] * len (controlnet_cond )
663+
661664 # check channel order
662665 channel_order = self .config .controlnet_conditioning_channel_order
663666
@@ -742,12 +745,12 @@ def forward(
742745 inputs = []
743746 condition_list = []
744747
745- for cond , control_idx in zip (controlnet_cond , control_type_idx ):
748+ for cond , control_idx , scale in zip (controlnet_cond , control_type_idx , conditioning_scale ):
746749 condition = self .controlnet_cond_embedding (cond )
747750 feat_seq = torch .mean (condition , dim = (2 , 3 ))
748751 feat_seq = feat_seq + self .task_embedding [control_idx ]
749- inputs .append (feat_seq .unsqueeze (1 ))
750- condition_list .append (condition )
752+ inputs .append (feat_seq .unsqueeze (1 ) * scale )
753+ condition_list .append (condition * scale )
751754
752755 condition = sample
753756 feat_seq = torch .mean (condition , dim = (2 , 3 ))
@@ -759,10 +762,10 @@ def forward(
759762 x = layer (x )
760763
761764 controlnet_cond_fuser = sample * 0.0
762- for idx , condition in enumerate (condition_list [:- 1 ]):
765+ for ( idx , condition ), scale in zip ( enumerate (condition_list [:- 1 ]), conditioning_scale ):
763766 alpha = self .spatial_ch_projs (x [:, idx ])
764767 alpha = alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
765- controlnet_cond_fuser += condition + alpha
768+ controlnet_cond_fuser += condition + alpha * scale
766769
767770 sample = sample + controlnet_cond_fuser
768771
@@ -806,12 +809,8 @@ def forward(
806809 # 6. scaling
807810 if guess_mode and not self .config .global_pool_conditions :
808811 scales = torch .logspace (- 1 , 0 , len (down_block_res_samples ) + 1 , device = sample .device ) # 0.1 to 1.0
809- scales = scales * conditioning_scale
810812 down_block_res_samples = [sample * scale for sample , scale in zip (down_block_res_samples , scales )]
811813 mid_block_res_sample = mid_block_res_sample * scales [- 1 ] # last one
812- else :
813- down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples ]
814- mid_block_res_sample = mid_block_res_sample * conditioning_scale
815814
816815 if self .config .global_pool_conditions :
817816 down_block_res_samples = [
0 commit comments