
import torch
import torch.nn as nn
- import torch.nn.functional as F
-
-
- def conv_nd(dims, *args, **kwargs):
-     """
-     Create a 1D, 2D, or 3D convolution module.
-     """
-     if dims == 1:
-         return nn.Conv1d(*args, **kwargs)
-     elif dims == 2:
-         return nn.Conv2d(*args, **kwargs)
-     elif dims == 3:
-         return nn.Conv3d(*args, **kwargs)
-     raise ValueError(f"unsupported dimensions: {dims}")
-
-
- # Copied from diffusers.models.controlnet.zero_module
- def zero_module(module: nn.Module) -> nn.Module:
-     for p in module.parameters():
-         nn.init.zeros_(p)
-     return module
+ from safetensors.torch import load_file


def get_clip_token_for_string(tokenizer, string):
@@ -45,111 +25,47 @@ def get_clip_token_for_string(tokenizer, string):
    return tokens[0, 1]


- def get_bert_token_for_string(tokenizer, string):
-     token = tokenizer(string)
-     assert (
-         torch.count_nonzero(token) == 3
-     ), f"String '{string}' maps to more than a single token. Please use another string"
-     token = token[0, 1]
-     return token
-
-
- def get_clip_vision_emb(encoder, processor, img):
-     _img = img.repeat(1, 3, 1, 1) * 255
-     inputs = processor(images=_img, return_tensors="pt")
-     inputs["pixel_values"] = inputs["pixel_values"].to(img.device)
-     outputs = encoder(**inputs)
-     emb = outputs.image_embeds
-     return emb
-
-
def get_recog_emb(encoder, img_list):
    _img_list = [(img.repeat(1, 3, 1, 1) * 255)[0] for img in img_list]
    encoder.predictor.eval()
    _, preds_neck = encoder.pred_imglist(_img_list, show_debug=False)
    return preds_neck


- def pad_H(x):
-     _, _, H, W = x.shape
-     p_top = (W - H) // 2
-     p_bot = W - H - p_top
-     return F.pad(x, (0, 0, p_top, p_bot))
-
-
- class EncodeNet(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(EncodeNet, self).__init__()
-         chan = 16
-         n_layer = 4  # downsample
-
-         self.conv1 = conv_nd(2, in_channels, chan, 3, padding=1)
-         self.conv_list = nn.ModuleList([])
-         _c = chan
-         for i in range(n_layer):
-             self.conv_list.append(conv_nd(2, _c, _c * 2, 3, padding=1, stride=2))
-             _c *= 2
-         self.conv2 = conv_nd(2, _c, out_channels, 3, padding=1)
-         self.avgpool = nn.AdaptiveAvgPool2d(1)
-         self.act = nn.SiLU()
-
-     def forward(self, x):
-         x = self.act(self.conv1(x))
-         for layer in self.conv_list:
-             x = self.act(layer(x))
-         x = self.act(self.conv2(x))
-         x = self.avgpool(x)
-         x = x.view(x.size(0), -1)
-         return x
-
-
class EmbeddingManager(nn.Module):
    def __init__(
        self,
        embedder,
-         position_channels=1,
        placeholder_string="*",
-         add_pos=False,
-         emb_type="ocr",
        use_fp16=False,
    ):
        super().__init__()
        get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer)
        token_dim = 768
        self.get_recog_emb = None
-         token_dim = 1280
        self.token_dim = token_dim
-         self.emb_type = emb_type

-         self.add_pos = add_pos
-         if add_pos:
-             self.position_encoder = EncodeNet(position_channels, token_dim)
-         if emb_type == "ocr":
-             self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim))
-             self.proj = self.proj.to(dtype=torch.float16 if use_fp16 else torch.float32)
+         self.proj = nn.Linear(40 * 64, token_dim)
+         self.proj.load_state_dict(load_file("EmbeddingManager/embedding_manager.safetensors", device="cpu"))
+         if use_fp16:
+             self.proj = self.proj.to(dtype=torch.float16)

        self.placeholder_token = get_token_for_string(placeholder_string)

+     @torch.no_grad()
    def encode_text(self, text_info):
-         if self.get_recog_emb is None and self.emb_type == "ocr":
+         if self.get_recog_emb is None:
            self.get_recog_emb = partial(get_recog_emb, self.recog)

        gline_list = []
-         pos_list = []
        for i in range(len(text_info["n_lines"])):  # sample index in a batch
            n_lines = text_info["n_lines"][i]
            for j in range(n_lines):  # line
                gline_list += [text_info["gly_line"][j][i : i + 1]]
-                 if self.add_pos:
-                     pos_list += [text_info["positions"][j][i : i + 1]]

        if len(gline_list) > 0:
-             if self.emb_type == "ocr":
-                 recog_emb = self.get_recog_emb(gline_list)
-                 enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
-             if self.add_pos:
-                 enc_pos = self.position_encoder(torch.cat(gline_list, dim=0))
-                 enc_glyph = enc_glyph + enc_pos
+             recog_emb = self.get_recog_emb(gline_list)
+             enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))

        self.text_embs_all = []
        n_idx = 0
@@ -161,6 +77,7 @@ def encode_text(self, text_info):
                n_idx += 1
            self.text_embs_all += [text_embs]

+     @torch.no_grad()
    def forward(
        self,
        tokenized_text,
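
For reference, a minimal usage sketch of the slimmed-down EmbeddingManager (not part of the commit). The names clip_embedder and ocr_recognizer are assumed stand-ins: the embedder must expose a CLIP tokenizer, the recognizer must provide the predictor and pred_imglist members used by get_recog_emb above, and the glyph-image shapes are illustrative only.

import torch

# assumed: clip_embedder wraps a CLIP text encoder and exposes .tokenizer
manager = EmbeddingManager(clip_embedder, placeholder_string="*", use_fp16=False)
manager.recog = ocr_recognizer  # encode_text reads self.recog; wired externally here

text_info = {
    "n_lines": [2],  # one batch sample containing two text lines
    "gly_line": [
        torch.rand(1, 1, 80, 512),  # one glyph image per line, indexed [line][sample]
        torch.rand(1, 1, 80, 512),
    ],
}
manager.encode_text(text_info)  # caches per-line glyph embeddings in self.text_embs_all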