|
4 | 4 | import numpy as np |
5 | 5 | import torch |
6 | 6 | from PIL import ImageFont |
| 7 | +from safetensors.torch import load_file |
7 | 8 | from torch import nn |
8 | 9 |
|
9 | 10 | from diffusers.utils import logging |
@@ -34,18 +35,12 @@ def retrieve_latents( |
34 | 35 |
|
35 | 36 |
|
36 | 37 | class AuxiliaryLatentModule(nn.Module): |
37 | | - def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): |
| 38 | + def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **kwargs): |
38 | 39 | super().__init__() |
39 | | - self.font = ImageFont.truetype("/home/cosmos/Documents/gits/AnyText/font/Arial_Unicode.ttf", 60) |
| 40 | + self.font = ImageFont.truetype("Arial_Unicode.ttf", 60) |
40 | 41 | self.use_fp16 = kwargs.get("use_fp16", False) |
41 | 42 | self.device = kwargs.get("device", "cpu") |
42 | | - self.model_channels = model_channels |
43 | | - time_embed_dim = model_channels * 4 |
44 | | - self.time_embed = nn.Sequential( |
45 | | - nn.Linear(model_channels, time_embed_dim), |
46 | | - nn.SiLU(), |
47 | | - nn.Linear(time_embed_dim, time_embed_dim), |
48 | | - ) |
| 43 | + |
49 | 44 | self.glyph_block = nn.Sequential( |
50 | 45 | nn.Conv2d(glyph_channels, 8, 3, padding=1), |
51 | 46 | nn.SiLU(), |
@@ -83,20 +78,21 @@ def __init__(self, dims=2, glyph_channels=1, position_channels=1, model_channels |
83 | 78 | nn.Conv2d(32, 64, 3, padding=1, stride=2), |
84 | 79 | nn.SiLU(), |
85 | 80 | ) |
86 | | - self.time_embed = self.time_embed.to(device="cuda", dtype=torch.float16) |
| 81 | + self.glyph_block.load_state_dict(load_file("glyph_block.safetensors")) |
| 82 | + self.position_block.load_state_dict(load_file("position_block.safetensors")) |
87 | 83 | self.glyph_block = self.glyph_block.to(device="cuda", dtype=torch.float16) |
88 | 84 | self.position_block = self.position_block.to(device="cuda", dtype=torch.float16) |
89 | 85 |
|
90 | 86 | self.vae = kwargs.get("vae") |
91 | 87 | self.vae.eval() |
92 | 88 |
|
93 | 89 | self.fuse_block = zero_module(nn.Conv2d(256 + 64 + 4, model_channels, 3, padding=1)) |
| 90 | + self.fuse_block.load_state_dict(load_file("fuse_block.safetensors")) |
94 | 91 | self.fuse_block = self.fuse_block.to(device="cuda", dtype=torch.float16) |
95 | 92 |
|
96 | 93 | @torch.no_grad() |
97 | 94 | def forward( |
98 | 95 | self, |
99 | | - context, |
100 | 96 | text_info, |
101 | 97 | mode, |
102 | 98 | draw_pos, |
|
0 commit comments