@@ -752,7 +752,7 @@ def forward(
752752            condition  =  self .controlnet_cond_embedding (cond )
753753            feat_seq  =  torch .mean (condition , dim = (2 , 3 ))
754754            feat_seq  =  feat_seq  +  self .task_embedding [control_idx ]
755-             if  from_multi :
755+             if  from_multi   or   len ( control_type_idx )  ==   1 :
756756                inputs .append (feat_seq .unsqueeze (1 ))
757757                condition_list .append (condition )
758758            else :
@@ -772,7 +772,7 @@ def forward(
772772        for  (idx , condition ), scale  in  zip (enumerate (condition_list [:- 1 ]), conditioning_scale ):
773773            alpha  =  self .spatial_ch_projs (x [:, idx ])
774774            alpha  =  alpha .unsqueeze (- 1 ).unsqueeze (- 1 )
775-             if  from_multi :
775+             if  from_multi   or   len ( control_type_idx )  ==   1 :
776776                controlnet_cond_fuser  +=  condition  +  alpha 
777777            else :
778778                controlnet_cond_fuser  +=  condition  +  alpha  *  scale 
@@ -819,11 +819,11 @@ def forward(
819819        # 6. scaling 
820820        if  guess_mode  and  not  self .config .global_pool_conditions :
821821            scales  =  torch .logspace (- 1 , 0 , len (down_block_res_samples ) +  1 , device = sample .device )  # 0.1 to 1.0 
822-             if  from_multi :
822+             if  from_multi   or   len ( control_type_idx )  ==   1 :
823823                scales  =  scales  *  conditioning_scale [0 ]
824824            down_block_res_samples  =  [sample  *  scale  for  sample , scale  in  zip (down_block_res_samples , scales )]
825825            mid_block_res_sample  =  mid_block_res_sample  *  scales [- 1 ]  # last one 
826-         elif  from_multi :
826+         elif  from_multi   or   len ( control_type_idx )  ==   1 :
827827            down_block_res_samples  =  [sample  *  conditioning_scale [0 ] for  sample  in  down_block_res_samples ]
828828            mid_block_res_sample  =  mid_block_res_sample  *  conditioning_scale [0 ]
829829
0 commit comments