Skip to content

Commit cffa036

Browse files
committed
Up
1 parent 73d8144 commit cffa036

File tree

3 files changed

+10
-15
lines changed

3 files changed

+10
-15
lines changed

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 7 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
import torch
66
from PIL import ImageFont
7+
from safetensors.torch import load_file
78
from torch import nn
89

910
from diffusers.utils import logging
@@ -34,18 +35,12 @@ def retrieve_latents(
3435

3536

3637
class AuxiliaryLatentModule(nn.Module):
37-
def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs):
38+
def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs):
3839
super().__init__()
39-
self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
40+
self.font = ImageFont.truetype("Arial_Unicode.ttf", 60)
4041
self.use_fp16 = kwargs.get("use_fp16", False)
4142
self.device = kwargs.get("device", "cpu")
42-
self.model_channels = model_channels
43-
time_embed_dim = model_channels * 4
44-
self.time_embed = nn.Sequential(
45-
nn.Linear(model_channels, time_embed_dim),
46-
nn.SiLU(),
47-
nn.Linear(time_embed_dim, time_embed_dim),
48-
)
43+
4944
self.glyph_block = nn.Sequential(
5045
nn.Conv2d(glyph_channels, 8, 3, padding=1),
5146
nn.SiLU(),
@@ -83,20 +78,21 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels
8378
nn.Conv2d(32, 64, 3, padding=1, stride=2),
8479
nn.SiLU(),
8580
)
86-
self.time_embed = self.time_embed.to(device="cuda", dtype=torch.float16)
81+
self.glyph_block.load_state_dict(load_file("glyph_block.safetensors"))
82+
self.position_block.load_state_dict(load_file("position_block.safetensors"))
8783
self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16)
8884
self.position_block = self.position_block.to(device="cuda", dtype=torch.float16)
8985

9086
self.vae = kwargs.get("vae")
9187
self.vae.eval()
9288

9389
self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
90+
self.fuse_block.load_state_dict(load_file("fuse_block.safetensors"))
9491
self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16)
9592

9693
@torch.no_grad()
9794
def forward(
9895
self,
99-
context,
10096
text_info,
10197
mode,
10298
draw_pos,

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1165,7 +1165,6 @@ def __call__(
11651165
# )
11661166
# height, width = image.shape[-2:]
11671167
guided_hint = self.auxiliary_latent_module(
1168-
context=prompt_embeds[1],
11691168
text_info=text_info,
11701169
mode=mode,
11711170
draw_pos=draw_pos,

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def __init__(self, use_fp16):
2121
self.use_fp16 = use_fp16
2222
self.device = "cuda" if torch.cuda.is_available() else "cpu"
2323
# TODO: Learn if the recommended font file is free to use
24-
self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60)
24+
self.font = ImageFont.truetype("Arial_Unicode.ttf", 60)
2525
self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device)
2626
self.embedding_manager_config = {
2727
"valid": True,
@@ -35,12 +35,12 @@ def __init__(self, use_fp16):
3535
# TODO: Understand the reason of param.requires_grad = True
3636
for param in self.embedding_manager.embedding_parameters():
3737
param.requires_grad = True
38-
rec_model_dir = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppv3_rec.pth"
38+
rec_model_dir = "ppv3_rec.pth"
3939
self.text_predictor = create_predictor(rec_model_dir).eval()
4040
args = {}
4141
args["rec_image_shape"] = "3, 48, 320"
4242
args["rec_batch_num"] = 6
43-
args["rec_char_dict_path"] = "/home/cosmos/Documents/gits/AnyText/ocr_weights/ppocr_keys_v1.txt"
43+
args["rec_char_dict_path"] = "ppocr_keys_v1.txt"
4444
args["use_fp16"] = False
4545
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
4646
for param in self.text_predictor.parameters():

0 commit comments

Comments
 (0)