@@ -194,6 +194,19 @@ def __init__(
194194        super ().__init__ ()
195195        if  isinstance (controlnet , (list , tuple )):
196196            controlnet  =  SD3MultiControlNetModel (controlnet )
197+         if  isinstance (controlnet , SD3MultiControlNetModel ):
198+             for  controlnet_model  in  controlnet .nets :
199+                 # for SD3.5 8b controlnet, it shares the pos_embed with the transformer 
200+                 if  (
201+                     hasattr (controlnet_model .config , "use_pos_embed" )
202+                     and  controlnet_model .config .use_pos_embed  is  False 
203+                 ):
204+                     pos_embed  =  controlnet_model ._get_pos_embed_from_transformer (transformer )
205+                     controlnet_model .pos_embed  =  pos_embed .to (controlnet_model .dtype ).to (controlnet_model .device )
206+         elif  isinstance (controlnet , SD3ControlNetModel ):
207+             if  hasattr (controlnet .config , "use_pos_embed" ) and  controlnet .config .use_pos_embed  is  False :
208+                 pos_embed  =  controlnet ._get_pos_embed_from_transformer (transformer )
209+                 controlnet .pos_embed  =  pos_embed .to (controlnet .dtype ).to (controlnet .device )
197210
198211        self .register_modules (
199212            vae = vae ,
@@ -1042,15 +1055,9 @@ def __call__(
10421055                        controlnet_cond_scale  =  controlnet_cond_scale [0 ]
10431056                    cond_scale  =  controlnet_cond_scale  *  controlnet_keep [i ]
10441057
1045-                 if  controlnet_config .use_pos_embed  is  False :
1046-                     # sd35 (offical) 8b controlnet 
1047-                     controlnet_model_input  =  self .transformer .pos_embed (latent_model_input )
1048-                 else :
1049-                     controlnet_model_input  =  latent_model_input 
1050- 
10511058                # controlnet(s) inference 
10521059                control_block_samples  =  self .controlnet (
1053-                     hidden_states = controlnet_model_input ,
1060+                     hidden_states = latent_model_input ,
10541061                    timestep = timestep ,
10551062                    encoder_hidden_states = controlnet_encoder_hidden_states ,
10561063                    pooled_projections = controlnet_pooled_projections ,
0 commit comments