@@ -39,9 +39,9 @@ class AnyTextControlNetConditioningEmbedding(nn.Module):
3939
4040 def __init__ (
4141 self ,
42+ conditioning_embedding_channels : int ,
4243 glyph_channels = 1 ,
4344 position_channels = 1 ,
44- model_channels = 320 ,
4545 ):
4646 super ().__init__ ()
4747
@@ -83,7 +83,7 @@ def __init__(
8383 nn .SiLU (),
8484 )
8585
86- self .fuse_block = nn .Conv2d (256 + 64 + 4 , model_channels , 3 , padding = 1 )
86+ self .fuse_block = nn .Conv2d (256 + 64 + 4 , conditioning_embedding_channels , 3 , padding = 1 )
8787
8888 # self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device)))
8989 # self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device)))
@@ -177,7 +177,7 @@ class conditioning with `class_embed_type` equal to `None`.
177177 def __init__ (
178178 self ,
179179 in_channels : int = 4 ,
180- conditioning_channels : int = 3 ,
180+ conditioning_channels : int = 1 ,
181181 flip_sin_to_cos : bool = True ,
182182 freq_shift : int = 0 ,
183183 down_block_types : Tuple [str , ...] = (
@@ -251,11 +251,12 @@ def __init__(
251251
252252 # control net conditioning embedding
253253 # TODO: what happens ControlNetModel's self.controlnet_cond_embedding's memory occupation?
254- self .controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding (
255- conditioning_embedding_channels = block_out_channels [0 ],
256- block_out_channels = conditioning_embedding_out_channels ,
257- conditioning_channels = conditioning_channels ,
258- )
254+ # self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
255+ # conditioning_embedding_channels=block_out_channels[0],
256+ # glyph_channels=conditioning_channels,
257+ # position_channels=conditioning_channels,
258+ # )
259+ self .controlnet_cond_embedding = None
259260
260261 def forward (
261262 self ,
0 commit comments