Skip to content

Commit f713171

Browse files
committed
Up
1 parent 18d3f60 commit f713171

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,13 @@ def retrieve_latents(
3535

3636

3737
class AuxiliaryLatentModule(nn.Module):
38-
def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs):
38+
def __init__(
39+
self, glyph_channels=1, position_channels=1, model_channels=320, vae=None, device="cpu", use_fp16=False
40+
):
3941
super().__init__()
4042
self.font = ImageFont.truetype("font/Arial_Unicode.ttf", 60)
41-
self.use_fp16 = kwargs.get("use_fp16", False)
42-
self.device = kwargs.get("device", "cpu")
43+
self.use_fp16 = use_fp16
44+
self.device = device
4345

4446
self.glyph_block = nn.Sequential(
4547
nn.Conv2d(glyph_channels, 8, 3, padding=1),
@@ -79,8 +81,7 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **
7981
nn.SiLU(),
8082
)
8183

82-
self.vae = kwargs.get("vae")
83-
self.vae.eval()
84+
self.vae = vae.eval()
8485

8586
self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1))
8687

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,15 +28,15 @@ def __init__(self, font_path, use_fp16=False, device="cpu"):
2828
for param in self.embedding_manager.embedding_parameters():
2929
param.requires_grad = True
3030
rec_model_dir = "OCR/ppv3_rec.safetensors"
31-
self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16)
31+
self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval()
32+
for param in self.text_predictor.parameters():
33+
param.requires_grad = False
3234
args = {}
3335
args["rec_image_shape"] = "3, 48, 320"
3436
args["rec_batch_num"] = 6
3537
args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
3638
args["use_fp16"] = self.use_fp16
37-
self.cn_recognizer = TextRecognizer(args, self.text_predictor, device=self.device, use_fp16=self.use_fp16)
38-
for param in self.text_predictor.parameters():
39-
param.requires_grad = False
39+
self.cn_recognizer = TextRecognizer(args, self.text_predictor)
4040
self.embedding_manager.recog = self.cn_recognizer
4141

4242
@torch.no_grad()

0 commit comments

Comments
 (0)