|
67 | 67 | unscale_lora_layers, |
68 | 68 | ) |
69 | 69 | from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor |
| 70 | +from diffusers.configuration_utils import register_to_config, ConfigMixin |
| 71 | +from diffusers.models.modeling_utils import ModelMixin |
70 | 72 |
|
71 | 73 |
|
72 | 74 | checker = BasicTokenizer() |
|
88 | 90 | >>> # load control net and stable diffusion v1-5 |
89 | 91 | >>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16, |
90 | 92 | ... variant="fp16",) |
91 | | - >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, |
| 93 | + >>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet, |
92 | 94 | ... torch_dtype=torch.float16, variant="fp16", |
93 | 95 | ... ).to("cuda") |
94 | 96 |
|
@@ -150,7 +152,7 @@ def __init__( |
150 | 152 | self.token_dim = token_dim |
151 | 153 |
|
152 | 154 | self.proj = nn.Linear(40 * 64, token_dim) |
153 | | - self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) |
| 155 | + # self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device))) |
154 | 156 | if use_fp16: |
155 | 157 | self.proj = self.proj.to(dtype=torch.float16) |
156 | 158 |
|
@@ -449,20 +451,19 @@ def get_ctcloss(self, preds, gt_text, weight): |
449 | 451 |
|
450 | 452 |
|
451 | 453 | class TextEmbeddingModule(nn.Module): |
| 454 | + # @register_to_config |
452 | 455 | def __init__(self, font_path, use_fp16=False, device="cpu"): |
453 | 456 | super().__init__() |
454 | | - self.use_fp16 = use_fp16 |
455 | | - self.device = device |
456 | 457 | # TODO: Learn if the recommended font file is free to use |
457 | 458 | self.font = ImageFont.truetype(font_path, 60) |
458 | | - self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16) |
459 | | - self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16) |
460 | | - rec_model_dir = "OCR/ppv3_rec.pth" |
461 | | - self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval() |
| 459 | + self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16) |
| 460 | + self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16) |
| 461 | + rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth" |
| 462 | + self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval() |
462 | 463 | args = {} |
463 | 464 | args["rec_image_shape"] = "3, 48, 320" |
464 | 465 | args["rec_batch_num"] = 6 |
465 | | - args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt" |
| 466 | + args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt" |
465 | | - args["use_fp16"] = self.use_fp16 |
| 466 | + args["use_fp16"] = use_fp16 |
467 | 468 | self.embedding_manager.recog = TextRecognizer(args, self.text_predictor) |
468 | 469 |
|
@@ -843,9 +844,6 @@ def insert_spaces(self, string, nSpace): |
843 | 844 |
|
844 | 845 | def to(self, device): |
845 | 846 | self.device = device |
846 | | - self.glyph_block = self.glyph_block.to(device) |
847 | | - self.position_block = self.position_block.to(device) |
848 | | - self.fuse_block = self.fuse_block.to(device) |
849 | 847 | self.vae = self.vae.to(device) |
850 | 848 | return self |
851 | 849 |
|
@@ -1011,8 +1009,8 @@ def __init__( |
1011 | 1009 | safety_checker=safety_checker, |
1012 | 1010 | feature_extractor=feature_extractor, |
1013 | 1011 | image_encoder=image_encoder, |
1014 | | - # text_embedding_module=text_embedding_module, |
1015 | | - # auxiliary_latent_module=auxiliary_latent_module, |
| 1012 | + # text_embedding_module=self.text_embedding_module, |
| 1013 | + # auxiliary_latent_module=self.auxiliary_latent_module, |
1016 | 1014 | ) |
1017 | 1015 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
1018 | 1016 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) |
|
0 commit comments