1414from typing import Any , Dict , Optional , Tuple , Union
1515
1616import torch
17- import torch .nn .functional as F
1817from torch import nn
1918
2019from diffusers .configuration_utils import register_to_config
class AnyTextControlNetConditioningEmbedding(nn.Module):
    """Fuse glyph, position and masked-latent inputs into one ControlNet hint.

    Two independent stacks of 3x3 convolutions (each followed by SiLU) encode
    the ``glyphs`` and ``positions`` inputs. Their outputs are concatenated
    along the channel axis with ``text_info["masked_x"]`` and projected to
    ``model_channels`` by a final 3x3 convolution.

    No pretrained weights are loaded and no dtype cast is applied in the
    constructor; all layers keep PyTorch's default initialisation.
    """

    # One (out_channels, stride) entry per 3x3 convolution, in layer order.
    # The order must not change: it fixes the ``nn.Sequential`` indices and
    # therefore the ``state_dict`` keys (``glyph_block.0.weight``, ...).
    _GLYPH_PLAN: Tuple[Tuple[int, int], ...] = (
        (8, 1),
        (8, 1),
        (16, 2),
        (16, 1),
        (32, 2),
        (32, 1),
        (96, 2),
        (96, 1),
        (256, 2),
    )
    _POSITION_PLAN: Tuple[Tuple[int, int], ...] = (
        (8, 1),
        (8, 1),
        (16, 2),
        (16, 1),
        (32, 2),
        (32, 1),
        (64, 2),
    )
    # Channel count that ``text_info["masked_x"]`` must contribute to the
    # fused tensor (the trailing ``4`` of the original ``256 + 64 + 4``).
    _MASKED_X_CHANNELS = 4

    def __init__(
        self,
        glyph_channels: int = 1,
        position_channels: int = 1,
        model_channels: int = 320,
    ):
        """Build the two encoder stacks and the fusing convolution.

        Args:
            glyph_channels: Channel count of the ``glyphs`` input.
            position_channels: Channel count of the ``positions`` input.
            model_channels: Channel count of the fused output.
        """
        super().__init__()

        # Construction order (glyph, position, fuse) is kept so that seeded
        # initialisation yields the same weights as the hand-written stacks.
        self.glyph_block = self._conv_silu_stack(glyph_channels, self._GLYPH_PLAN)
        self.position_block = self._conv_silu_stack(position_channels, self._POSITION_PLAN)

        fused_channels = self._GLYPH_PLAN[-1][0] + self._POSITION_PLAN[-1][0] + self._MASKED_X_CHANNELS
        self.fuse_block = nn.Conv2d(fused_channels, model_channels, 3, padding=1)

    @staticmethod
    def _conv_silu_stack(in_channels: int, plan: Tuple[Tuple[int, int], ...]) -> nn.Sequential:
        """Return a ``Sequential`` of (3x3 conv, SiLU) pairs described by ``plan``."""
        layers = []
        for out_channels, stride in plan:
            layers.append(nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=stride))
            layers.append(nn.SiLU())
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, glyphs: torch.Tensor, positions: torch.Tensor, text_info: Dict[str, Any]) -> torch.Tensor:
        """Encode both inputs and fuse them with ``text_info["masked_x"]``.

        The glyph stack contains four stride-2 convolutions and the position
        stack three, so ``glyphs`` is reduced 16x and ``positions`` 8x in each
        spatial dimension (rounding up). ``torch.cat`` along ``dim=1`` requires
        the three tensors to agree in every other dimension, and ``fuse_block``
        requires ``masked_x`` to supply 4 channels.

        Args:
            glyphs: Input to ``glyph_block``, with ``glyph_channels`` channels.
            positions: Input to ``position_block``, with ``position_channels``
                channels.
            text_info: Mapping from which only the ``"masked_x"`` entry is read.

        Returns:
            The fused hint tensor with ``model_channels`` channels.

        Raises:
            KeyError: If ``text_info`` has no ``"masked_x"`` entry.
        """
        glyph_embedding = self.glyph_block(glyphs)
        position_embedding = self.position_block(positions)
        fused_input = torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1)
        return self.fuse_block(fused_input)
74103
75104
76105class AnyTextControlNetModel (ControlNetModel ):
0 commit comments