@@ -55,7 +55,7 @@ def __init__(
5555        guidance_embeds : bool  =  False ,
5656        axes_dims_rope : List [int ] =  [16 , 56 , 56 ],
5757        num_mode : int  =  None ,
58-         is_xlabs_controlnet :  bool  =  False ,
58+         conditioning_embedding_channels :  int  =  None ,
5959    ):
6060        super ().__init__ ()
6161        self .out_channels  =  in_channels 
@@ -107,13 +107,14 @@ def __init__(
107107        if  self .union :
108108            self .controlnet_mode_embedder  =  nn .Embedding (num_mode , self .inner_dim )
109109
110-         if  self . is_xlabs_controlnet :
110+         if  conditioning_embedding_channels   is   not   None :
111111            self .input_hint_block  =  ControlNetConditioningEmbedding (
112-                 conditioning_embedding_channels = 16 ,
112+                 conditioning_embedding_channels = conditioning_embedding_channels ,
113113                block_out_channels = (16 ,16 ,16 ,16 )
114114            )
115115            self .controlnet_x_embedder  =  torch .nn .Linear (in_channels , self .inner_dim )
116116        else :
117+             self .input_hint_block  =  None 
117118            self .controlnet_x_embedder  =  zero_module (torch .nn .Linear (in_channels , self .inner_dim ))
118119
119120        self .gradient_checkpointing  =  False 
@@ -277,7 +278,7 @@ def forward(
277278                )
278279        hidden_states  =  self .x_embedder (hidden_states )
279280
280-         if  self .is_xlabs_controlnet :
281+         if  self .input_hint_block   is   not   None :
281282            controlnet_cond  =  self .input_hint_block (controlnet_cond )
282283            batch_size , channels , height_pw , width_pw  =  controlnet_cond .shape 
283284            height  =  height_pw  //  self .config .patch_size 
0 commit comments