@@ -752,7 +752,7 @@ def forward(
752752 condition = self .controlnet_cond_embedding (cond )
753753 feat_seq = torch .mean (condition , dim = (2 , 3 ))
754754 feat_seq = feat_seq + self .task_embedding [control_idx ]
755- if from_multi :
755+ if from_multi or len ( control_type_idx ) == 1 :
756756 inputs .append (feat_seq .unsqueeze (1 ))
757757 condition_list .append (condition )
758758 else :
@@ -772,7 +772,7 @@ def forward(
772772 for (idx , condition ), scale in zip (enumerate (condition_list [:- 1 ]), conditioning_scale ):
773773 alpha = self .spatial_ch_projs (x [:, idx ])
774774 alpha = alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
775- if from_multi :
775+ if from_multi or len ( control_type_idx ) == 1 :
776776 controlnet_cond_fuser += condition + alpha
777777 else :
778778 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ def forward(
819819 # 6. scaling
820820 if guess_mode and not self .config .global_pool_conditions :
821821 scales = torch .logspace (- 1 , 0 , len (down_block_res_samples ) + 1 , device = sample .device ) # 0.1 to 1.0
822- if from_multi :
822+ if from_multi or len ( control_type_idx ) == 1 :
823823 scales = scales * conditioning_scale [0 ]
824824 down_block_res_samples = [sample * scale for sample , scale in zip (down_block_res_samples , scales )]
825825 mid_block_res_sample = mid_block_res_sample * scales [- 1 ] # last one
826- elif from_multi :
826+ elif from_multi or len ( control_type_idx ) == 1 :
827827 down_block_res_samples = [sample * conditioning_scale [0 ] for sample in down_block_res_samples ]
828828 mid_block_res_sample = mid_block_res_sample * conditioning_scale [0 ]
829829
0 commit comments