Skip to content

Commit 2b6f08b

Browse files
committed
Refactor AnyTextControlNet to use configurable conditioning embedding channels
1 parent 0c94143 commit 2b6f08b

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

examples/research_projects/anytext/anytext_controlnet.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ class AnyTextControlNetConditioningEmbedding(nn.Module):
3939

4040
def __init__(
4141
self,
42+
conditioning_embedding_channels: int,
4243
glyph_channels=1,
4344
position_channels=1,
44-
model_channels=320,
4545
):
4646
super().__init__()
4747

@@ -83,7 +83,7 @@ def __init__(
8383
nn.SiLU(),
8484
)
8585

86-
self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
86+
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
8787

8888
# self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device)))
8989
# self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device)))
@@ -177,7 +177,7 @@ class conditioning with `class_embed_type` equal to `None`.
177177
def __init__(
178178
self,
179179
in_channels: int = 4,
180-
conditioning_channels: int = 3,
180+
conditioning_channels: int = 1,
181181
flip_sin_to_cos: bool = True,
182182
freq_shift: int = 0,
183183
down_block_types: Tuple[str, ...] = (
@@ -251,11 +251,12 @@ def __init__(
251251

252252
# control net conditioning embedding
253253
# TODO: what happens ControlNetModel's self.controlnet_cond_embedding's memory occupation?
254-
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
255-
conditioning_embedding_channels=block_out_channels[0],
256-
block_out_channels=conditioning_embedding_out_channels,
257-
conditioning_channels=conditioning_channels,
258-
)
254+
# self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
255+
# conditioning_embedding_channels=block_out_channels[0],
256+
# glyph_channels=conditioning_channels,
257+
# position_channels=conditioning_channels,
258+
# )
259+
self.controlnet_cond_embedding = None
259260

260261
def forward(
261262
self,

0 commit comments

Comments
 (0)