
Commit c4db96a

Move custom blocks from AuxiliaryLatentModule to AnyTextControlNetConditioningEmbedding
1 parent 21c0c35 commit c4db96a

2 files changed: +53 -80 lines changed

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 1 addition & 57 deletions
@@ -33,7 +33,6 @@
 from frozen_clip_embedder_t3 import FrozenCLIPEmbedderT3
 from PIL import Image, ImageDraw, ImageFont
 from recognizer import TextRecognizer, create_predictor
-from safetensors.torch import load_file
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

@@ -410,9 +409,6 @@ class AuxiliaryLatentModule(nn.Module):
     def __init__(
         self,
         font_path,
-        glyph_channels=1,
-        position_channels=1,
-        model_channels=320,
         vae=None,
         device="cpu",
         use_fp16=False,
@@ -422,57 +418,8 @@ def __init__(
         self.use_fp16 = use_fp16
         self.device = device
 
-        self.glyph_block = nn.Sequential(
-            nn.Conv2d(glyph_channels, 8, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(8, 8, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(8, 16, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 32, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(32, 32, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(32, 96, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(96, 96, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(96, 256, 3, padding=1, stride=2),
-            nn.SiLU(),
-        )
-
-        self.position_block = nn.Sequential(
-            nn.Conv2d(position_channels, 8, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(8, 8, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(8, 16, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(16, 16, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(16, 32, 3, padding=1, stride=2),
-            nn.SiLU(),
-            nn.Conv2d(32, 32, 3, padding=1),
-            nn.SiLU(),
-            nn.Conv2d(32, 64, 3, padding=1, stride=2),
-            nn.SiLU(),
-        )
-
         self.vae = vae.eval() if vae is not None else None
 
-        self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
-
-        self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device)))
-        self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device)))
-        self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device)))
-
-        if use_fp16:
-            self.glyph_block = self.glyph_block.to(dtype=torch.float16)
-            self.position_block = self.position_block.to(dtype=torch.float16)
-            self.fuse_block = self.fuse_block.to(dtype=torch.float16)
-
     @torch.no_grad()
     def forward(
         self,
@@ -518,11 +465,8 @@ def forward(
 
         glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
         positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
-        enc_glyph = self.glyph_block(glyphs)
-        enc_pos = self.position_block(positions)
-        guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"]], dim=1))
 
-        return guided_hint
+        return glyphs, positions, text_info
 
     def check_channels(self, image):
         channels = image.shape[2] if len(image.shape) == 3 else 1
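
Note: `AuxiliaryLatentModule.forward` now stops at aggregating the glyph and position hint maps and returns them, together with `text_info`, instead of a fused embedding. The `glyph_block`/`position_block`/`fuse_block` encoding moves into `AnyTextControlNetConditioningEmbedding` in the second file, which is also why the `safetensors.torch.load_file` import and the state-dict loading are dropped from the pipeline.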

examples/research_projects/anytext/text_controlnet.py

Lines changed: 52 additions & 23 deletions
@@ -14,7 +14,6 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 from diffusers.configuration_utils import register_to_config
@@ -40,37 +39,67 @@ class AnyTextControlNetConditioningEmbedding(nn.Module):
 
     def __init__(
         self,
-        conditioning_embedding_channels: int,
-        conditioning_channels: int = 3,
-        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+        glyph_channels=1,
+        position_channels=1,
+        model_channels=320,
     ):
         super().__init__()
 
-        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
-
-        self.blocks = nn.ModuleList([])
-
-        for i in range(len(block_out_channels) - 1):
-            channel_in = block_out_channels[i]
-            channel_out = block_out_channels[i + 1]
-            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
-            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+        self.glyph_block = nn.Sequential(
+            nn.Conv2d(glyph_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(8, 8, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            nn.Conv2d(16, 16, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            nn.Conv2d(32, 32, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(32, 96, 3, padding=1, stride=2),
+            nn.SiLU(),
+            nn.Conv2d(96, 96, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(96, 256, 3, padding=1, stride=2),
+            nn.SiLU(),
+        )
 
-        self.conv_out = zero_module(
-            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+        self.position_block = nn.Sequential(
+            nn.Conv2d(position_channels, 8, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(8, 8, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(8, 16, 3, padding=1, stride=2),
+            nn.SiLU(),
+            nn.Conv2d(16, 16, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(16, 32, 3, padding=1, stride=2),
+            nn.SiLU(),
+            nn.Conv2d(32, 32, 3, padding=1),
+            nn.SiLU(),
+            nn.Conv2d(32, 64, 3, padding=1, stride=2),
+            nn.SiLU(),
         )
 
-    def forward(self, conditioning):
-        embedding = self.conv_in(conditioning)
-        embedding = F.silu(embedding)
+        self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
+
+        # self.glyph_block.load_state_dict(load_file("glyph_block.safetensors", device=str(self.device)))
+        # self.position_block.load_state_dict(load_file("position_block.safetensors", device=str(self.device)))
+        # self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device)))
 
-        for block in self.blocks:
-            embedding = block(embedding)
-            embedding = F.silu(embedding)
+        # if use_fp16:
+        #     self.glyph_block = self.glyph_block.to(dtype=torch.float16)
+        #     self.position_block = self.position_block.to(dtype=torch.float16)
+        #     self.fuse_block = self.fuse_block.to(dtype=torch.float16)
 
-        embedding = self.conv_out(embedding)
+    def forward(self, glyphs, positions, text_info):
+        glyph_embedding = self.glyph_block(glyphs)
+        position_embedding = self.position_block(positions)
+        guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
 
-        return embedding
+        return guided_hint
 
 
 class AnyTextControlNetModel(ControlNetModel):
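
As a sanity check of the relocated blocks, here is a minimal, hypothetical usage sketch (not part of this commit). It assumes `text_controlnet.py` is importable from the working directory and that the usual AnyText resolutions apply; the input sizes below are inferred from the stride-2 convolutions, since the glyph map must start at twice the position map's resolution for all three tensors to meet on the same 64x64 latent grid as `masked_x`:

import torch

from text_controlnet import AnyTextControlNetConditioningEmbedding

# Randomly initialized; the commented-out load_state_dict calls in the diff
# suggest pretrained weights are loaded elsewhere after this refactor.
cond_embedding = AnyTextControlNetConditioningEmbedding(
    glyph_channels=1, position_channels=1, model_channels=320
)

# Assumed shapes for a 512x512 image with 64x64, 4-channel VAE latents:
glyphs = torch.randn(1, 1, 1024, 1024)   # glyph_block downsamples 16x -> (1, 256, 64, 64)
positions = torch.randn(1, 1, 512, 512)  # position_block downsamples 8x -> (1, 64, 64, 64)
text_info = {"masked_x": torch.randn(1, 4, 64, 64)}

# fuse_block concatenates 256 + 64 + 4 channels and projects to model_channels.
guided_hint = cond_embedding(glyphs, positions, text_info)
print(guided_hint.shape)  # torch.Size([1, 320, 64, 64])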
