@@ -52,7 +52,7 @@ def forward(self, x):
5252
5353
5454class ControlNetFlux (Flux ):
55- def __init__ (self , latent_input = False , num_union_modes = 0 , mistoline = False , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
55+ def __init__ (self , latent_input = False , num_union_modes = 0 , mistoline = False , control_latent_channels = None , image_model = None , dtype = None , device = None , operations = None , ** kwargs ):
5656 super ().__init__ (final_layer = False , dtype = dtype , device = device , operations = operations , ** kwargs )
5757
5858 self .main_model_double = 19
@@ -80,7 +80,12 @@ def __init__(self, latent_input=False, num_union_modes=0, mistoline=False, image
8080
8181 self .gradient_checkpointing = False
8282 self .latent_input = latent_input
83- self .pos_embed_input = operations .Linear (self .in_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
83+ if control_latent_channels is None :
84+ control_latent_channels = self .in_channels
85+ else :
86+ control_latent_channels *= 2 * 2 #patch size
87+
88+ self .pos_embed_input = operations .Linear (control_latent_channels , self .hidden_size , bias = True , dtype = dtype , device = device )
8489 if not self .latent_input :
8590 if self .mistoline :
8691 self .input_cond_block = MistolineCondDownsamplBlock (dtype = dtype , device = device , operations = operations )
0 commit comments