
Commit d52e973

refactor: Simplify code for loading models and handling data types
1 parent f347ff2 commit d52e973

6 files changed, +33 -18 lines changed

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 5 additions & 5 deletions

@@ -29,10 +29,10 @@ def retrieve_latents(
 
 class AuxiliaryLatentModule(nn.Module):
     def __init__(
-        self, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False
+        self, font_path, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False
     ):
         super().__init__()
-        self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60)
+        self.font = ImageFont.truetype(font_path, 60)
         self.use_fp16 = use_fp16
         self.device = device
 
@@ -79,12 +79,12 @@ def __init__(
         self.fuse_block = nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)
 
         self.glyph_block.load_state_dict(
-            load_file("AuxiliaryLatentModule/glyph_block.safetensors", device=self.device)
+            load_file("glyph_block.safetensors", device=str(self.device))
         )
         self.position_block.load_state_dict(
-            load_file("AuxiliaryLatentModule/position_block.safetensors", device=self.device)
+            load_file("position_block.safetensors", device=str(self.device))
         )
-        self.fuse_block.load_state_dict(load_file("AuxiliaryLatentModule/fuse_block.safetensors", device=self.device))
+        self.fuse_block.load_state_dict(load_file("fuse_block.safetensors", device=str(self.device)))
 
         if use_fp16:
             self.glyph_block = self.glyph_block.to(dtype=torch.float16)
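
Note on the device=str(...) change: safetensors' load_file takes the device as a string (or int) identifier, so passing a torch.device object can fail depending on the safetensors version; wrapping it in str() is the safe form. A minimal sketch of the pattern (file name and device are placeholders):

    import torch
    from safetensors.torch import load_file

    device = torch.device("cpu")
    # str(device) yields "cpu" / "cuda:0", which load_file accepts
    state_dict = load_file("glyph_block.safetensors", device=str(device))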

examples/research_projects/anytext/embedding_manager.py

Lines changed: 2 additions & 2 deletions

@@ -46,7 +46,7 @@ def __init__(
         self.token_dim = token_dim
 
         self.proj = nn.Linear(40 * 64, token_dim)
-        self.proj.load_state_dict(load_file("EmbeddingManager/embedding_manager.safetensors", device=self.device))
+        self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
         if use_fp16:
             self.proj = self.proj.to(dtype=torch.float16)
 
@@ -65,7 +65,7 @@ def encode_text(self, text_info):
 
         if len(gline_list) > 0:
             recog_emb = self.get_recog_emb(gline_list)
-            enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1))
+            enc_glyph = self.proj(recog_emb.reshape(recog_emb.shape[0], -1).to(self.proj.weight.device))
 
         self.text_embs_all = []
         n_idx = 0
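
The .to(self.proj.weight.device) added in encode_text guards against the OCR features living on a different device than the projection layer. A minimal self-contained sketch of the pattern (the batch shape and output width 768 are illustrative, not taken from the file):

    import torch
    import torch.nn as nn

    proj = nn.Linear(40 * 64, 768)
    recog_emb = torch.randn(2, 40, 64)  # stand-in for OCR recognition features
    # flatten per sample, then move onto whatever device proj's weights occupy
    flat = recog_emb.reshape(recog_emb.shape[0], -1).to(proj.weight.device)
    enc_glyph = proj(flat)  # shape (2, 768)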

examples/research_projects/anytext/ocr_recog/RecModel.py

Lines changed: 2 additions & 0 deletions

@@ -34,6 +34,8 @@ def load_3rd_state_dict(self, _3rd_name, _state):
         self.head.load_3rd_state_dict(_3rd_name, _state)
 
     def forward(self, x):
+        import torch
+        x = x.to(torch.float32)
         x = self.backbone(x)
         x = self.neck(x)
         x = self.head(x)
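
Upcasting at the top of forward pins the recognition model to full precision even when the caller runs the rest of the pipeline in fp16. A minimal sketch of the effect (the input shape is a hypothetical OCR crop, not taken from the file):

    import torch

    x = torch.randn(1, 3, 48, 320, dtype=torch.float16)  # hypothetical fp16 input
    x = x.to(torch.float32)  # what RecModel.forward now does before backbone/neck/head
    assert x.dtype == torch.float32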

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 4 additions & 2 deletions

@@ -208,6 +208,7 @@ class AnyTextPipeline(
 
     def __init__(
         self,
+        font_path: str,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
@@ -218,7 +219,6 @@ def __init__(
         feature_extractor: CLIPImageProcessor,
         image_encoder: CLIPVisionModelWithProjection = None,
         requires_safety_checker: bool = True,
-        font_path: str = "font/Arial_Unicode.ttf",
     ):
         super().__init__()
         self.text_embedding_module = TextEmbeddingModule(
@@ -257,13 +257,15 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
+            # text_embedding_module=text_embedding_module,
+            # auxiliary_latent_module=auxiliary_latent_module,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
+        self.register_to_config(requires_safety_checker=requires_safety_checker, font_path=font_path)
 
     def modify_prompt(self, prompt):
         prompt = prompt.replace("“", '"')
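
Because font_path is now passed to register_to_config, it is serialized with the pipeline's config rather than hard-coded, so it can round-trip through save_pretrained / from_pretrained. A minimal sketch of that mechanism using diffusers' ConfigMixin (the class below is illustrative, not from the repo):

    from diffusers.configuration_utils import ConfigMixin

    class TinyConfigured(ConfigMixin):
        config_name = "config.json"

        def __init__(self, font_path: str):
            # registered values are written to the config file on save
            self.register_to_config(font_path=font_path)

    obj = TinyConfigured("font/Arial_Unicode.ttf")
    print(obj.config.font_path)  # persisted alongside the other config entries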

examples/research_projects/anytext/recognizer.py

Lines changed: 1 addition & 3 deletions

@@ -98,9 +98,7 @@ def create_predictor(model_dir=None, model_lang="ch", device="cpu", use_fp16=Fal
 
     rec_model = RecModel(rec_config)
     if model_file_path is not None:
-        rec_model.load_state_dict(load_file(model_file_path, device=device)).to(
-            dtype=torch.float16 if use_fp16 else torch.float32
-        )
+        rec_model.load_state_dict(torch.load(model_file_path, map_location=device))
     return rec_model
 
 
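
Two things change here: the checkpoint is loaded with torch.load(..., map_location=device) instead of safetensors, and a bug disappears. nn.Module.load_state_dict returns a key-mismatch report, not the module, so chaining .to(dtype=...) onto its result (as the removed lines did) would raise an AttributeError. A minimal sketch demonstrating both points (the file path is a stand-in):

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)  # stand-in for RecModel
    torch.save(model.state_dict(), "/tmp/rec.pth")

    state = torch.load("/tmp/rec.pth", map_location="cpu")  # tensors land on the target device
    result = model.load_state_dict(state)  # mutates model in place
    print(result)  # <All keys matched successfully>; not a module, so .to() on it would fail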

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 19 additions & 6 deletions

@@ -24,10 +24,9 @@ def __init__(self, font_path, use_fp16=False, device="cpu"):
         self.font = ImageFont.truetype(font_path, 60)
         self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16)
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16)
-        # TODO: Understand the reason of param.requires_grad = True
-        for param in self.embedding_manager.embedding_parameters():
-            param.requires_grad = True
-        rec_model_dir = "OCR/ppv3_rec.safetensors"
+        # for param in self.embedding_manager.embedding_parameters():
+        #     param.requires_grad = True
+        rec_model_dir = "OCR/ppv3_rec.pth"
         self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval()
         for param in self.text_predictor.parameters():
             param.requires_grad = False
@@ -36,8 +35,7 @@ def __init__(self, font_path, use_fp16=False, device="cpu"):
         args["rec_batch_num"] = 6
         args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
         args["use_fp16"] = self.use_fp16
-        self.cn_recognizer = TextRecognizer(args, self.text_predictor)
-        self.embedding_manager.recog = self.cn_recognizer
+        self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
 
     @torch.no_grad()
     def forward(
@@ -290,3 +288,18 @@ def draw_glyph2(self, font, text, polygon, vertAng=10, scale=1, width=512, heigh
         img.paste(rotated_layer, (x_offset, y_offset), rotated_layer)
         img = np.expand_dims(np.array(img.convert("1")), axis=2).astype(np.float64)
         return img
+
+    def insert_spaces(self, string, nSpace):
+        if nSpace == 0:
+            return string
+        new_string = ""
+        for char in string:
+            new_string += char + " " * nSpace
+        return new_string[:-nSpace]
+
+    def to(self, device):
+        self.device = device
+        self.frozen_CLIP_embedder_t3.to(device)
+        self.embedding_manager.to(device)
+        self.text_predictor.to(device)
+        return self
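
The new to() override exists because TextEmbeddingModule tracks a device attribute and holds several components (the CLIP embedder, the embedding manager, and the OCR predictor) that must all move together; returning self preserves the usual module.to(device) chaining idiom. A minimal sketch of the pattern (the container class is illustrative):

    import torch
    import torch.nn as nn

    class Bundle(nn.Module):
        def __init__(self):
            super().__init__()
            self.device = torch.device("cpu")
            self.inner = nn.Linear(8, 8)

        def to(self, device):
            self.device = torch.device(device)  # keep the bookkeeping attribute in sync
            self.inner.to(device)               # forward the move to each held component
            return self                         # allow `bundle = bundle.to("cuda")`

    bundle = Bundle().to("cpu")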
