Skip to content

Commit b8ca0d6

Browse files
committed
up
1 parent 2ffb80b commit b8ca0d6

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 12 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,8 @@
6767
unscale_lora_layers,
6868
)
6969
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
70+
from diffusers.configuration_utils import register_to_config, ConfigMixin
71+
from diffusers.models.modeling_utils import ModelMixin
7072

7173

7274
checker = BasicTokenizer()
@@ -88,7 +90,7 @@
8890
>>> # load control net and stable diffusion v1-5
8991
>>> text_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
9092
... variant="fp16",)
91-
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet,
93+
>>> pipe = AnyTextPipeline.from_pretrained("tolgacangoz/anytext", controlnet=text_controlnet,
9294
... torch_dtype=torch.float16, variant="fp16",
9395
... ).to("cuda")
9496
@@ -150,7 +152,7 @@ def __init__(
150152
self.token_dim = token_dim
151153

152154
self.proj = nn.Linear(40 * 64, token_dim)
153-
self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
155+
# self.proj.load_state_dict(load_file("proj.safetensors", device=str(embedder.device)))
154156
if use_fp16:
155157
self.proj = self.proj.to(dtype=torch.float16)
156158

@@ -449,20 +451,19 @@ def get_ctcloss(self, preds, gt_text, weight):
449451

450452

451453
class TextEmbeddingModule(nn.Module):
454+
# @register_to_config
452455
def __init__(self, font_path, use_fp16=False, device="cpu"):
453456
super().__init__()
454-
self.use_fp16 = use_fp16
455-
self.device = device
456457
# TODO: Learn if the recommended font file is free to use
457458
self.font = ImageFont.truetype(font_path, 60)
458-
self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=self.device, use_fp16=self.use_fp16)
459-
self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=self.use_fp16)
460-
rec_model_dir = "OCR/ppv3_rec.pth"
461-
self.text_predictor = create_predictor(rec_model_dir, device=self.device, use_fp16=self.use_fp16).eval()
459+
self.frozen_CLIP_embedder_t3 = FrozenCLIPEmbedderT3(device=device, use_fp16=use_fp16)
460+
self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, use_fp16=use_fp16)
461+
rec_model_dir = "./text_embedding_module/OCR/ppv3_rec.pth"
462+
self.text_predictor = create_predictor(rec_model_dir, device=device, use_fp16=use_fp16).eval()
462463
args = {}
463464
args["rec_image_shape"] = "3, 48, 320"
464465
args["rec_batch_num"] = 6
465-
args["rec_char_dict_path"] = "OCR/ppocr_keys_v1.txt"
466+
args["rec_char_dict_path"] = "./text_embedding_module/OCR/ppocr_keys_v1.txt"
466467
args["use_fp16"] = self.use_fp16
467468
self.embedding_manager.recog = TextRecognizer(args, self.text_predictor)
468469

@@ -843,9 +844,6 @@ def insert_spaces(self, string, nSpace):
843844

844845
def to(self, device):
845846
self.device = device
846-
self.glyph_block = self.glyph_block.to(device)
847-
self.position_block = self.position_block.to(device)
848-
self.fuse_block = self.fuse_block.to(device)
849847
self.vae = self.vae.to(device)
850848
return self
851849

@@ -1011,8 +1009,8 @@ def __init__(
10111009
safety_checker=safety_checker,
10121010
feature_extractor=feature_extractor,
10131011
image_encoder=image_encoder,
1014-
# text_embedding_module=text_embedding_module,
1015-
# auxiliary_latent_module=auxiliary_latent_module,
1012+
# text_embedding_module=self.text_embedding_module,
1013+
# auxiliary_latent_module=self.auxiliary_latent_module,
10161014
)
10171015
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
10181016
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)

0 commit comments

Comments
 (0)