File tree Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Expand file tree Collapse file tree 2 files changed +7
-2
lines changed Original file line number Diff line number Diff line change @@ -402,6 +402,7 @@ def forward(
402402        controlnet_block_samples = None ,
403403        controlnet_single_block_samples = None ,
404404        return_dict : bool  =  True ,
405+         controlnet_blocks_repeat : bool  =  False ,
405406    ) ->  Union [torch .FloatTensor , Transformer2DModelOutput ]:
406407        """ 
407408        The [`FluxTransformer2DModel`] forward method. 
@@ -509,8 +510,8 @@ def custom_forward(*inputs):
509510                interval_control  =  len (self .transformer_blocks ) /  len (controlnet_block_samples )
510511                interval_control  =  int (np .ceil (interval_control ))
511512                # For Xlabs ControlNet. 
512-                 if  len ( controlnet_block_samples )  ==   2 :
513-                     hidden_states  =  hidden_states  +  controlnet_block_samples [index_block  %  2 ]
513+                 if  controlnet_blocks_repeat :
514+                     hidden_states  =  hidden_states  +  controlnet_block_samples [index_block  %  len ( controlnet_block_samples ) ]
514515                else :
515516                    hidden_states  =  hidden_states  +  controlnet_block_samples [index_block  //  interval_control ]
516517
Original file line number Diff line number Diff line change @@ -739,6 +739,7 @@ def __call__(
739739        )
740740
741741        # 3. Prepare control image 
742+         controlnet_blocks_repeat  =  False 
742743        num_channels_latents  =  self .transformer .config .in_channels  //  4 
743744        if  isinstance (self .controlnet , FluxControlNetModel ):
744745            control_image  =  self .prepare_image (
@@ -766,6 +767,8 @@ def __call__(
766767                    height_control_image ,
767768                    width_control_image ,
768769                )
770+             else :
771+                 controlnet_blocks_repeat  =  True 
769772
770773            # Here we ensure that `control_mode` has the same length as the control_image. 
771774            if  control_mode  is  not None :
@@ -926,6 +929,7 @@ def __call__(
926929                    img_ids = latent_image_ids ,
927930                    joint_attention_kwargs = self .joint_attention_kwargs ,
928931                    return_dict = False ,
932+                     controlnet_blocks_repeat = controlnet_blocks_repeat ,
929933                )[0 ]
930934
931935                # compute the previous noisy sample x_t -> x_t-1 
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments