|
1 | | -# text -> glyph render -> glyph l_g -> glyph block -> |
2 | | -# +> fuse layer |
3 | | -# position l_p -> position block -> |
4 | | - |
5 | | -import math |
6 | 1 | from typing import Optional |
7 | 2 |
|
8 | 3 | import cv2 |
9 | 4 | import numpy as np |
10 | 5 | import torch |
11 | | -from einops import repeat |
12 | 6 | from PIL import ImageFont |
13 | 7 | from torch import nn |
14 | 8 |
|
@@ -146,39 +140,12 @@ def forward( |
146 | 140 |
|
147 | 141 | glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True) |
148 | 142 | positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True) |
149 | | - t_emb = self.timestep_embedding(torch.tensor([1000], device="cuda"), self.model_channels, repeat_only=False) |
150 | | - if self.use_fp16: |
151 | | - t_emb = t_emb.half() |
152 | | - emb = self.time_embed(t_emb) |
153 | | - print(glyphs.shape, emb.shape, positions.shape, context.shape) |
154 | | - enc_glyph = self.glyph_block(glyphs.cuda(), emb, context) |
155 | | - enc_pos = self.position_block(positions.cuda(), emb, context) |
| 143 | + enc_glyph = self.glyph_block(glyphs.cuda()) |
| 144 | + enc_pos = self.position_block(positions.cuda()) |
156 | 145 | guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1)) |
157 | 146 |
|
158 | 147 | return guided_hint |
159 | 148 |
|
def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and return each
                        timestep value copied ``dim`` times along a new last axis.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    # Imported locally so this method does not depend on a module-level import.
    import math

    if repeat_only:
        # Torch-native equivalent of einops ``repeat(timesteps, "b -> b d", d=dim)``.
        return timesteps[:, None].repeat(1, dim)

    half = dim // 2
    # Create the frequency table directly on the target device instead of
    # building it on the CPU and copying it over afterwards.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(-math.log(max_period) * exponents / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad odd dimensions with one explicit zero column. Slicing the padding
        # out of ``embedding`` would produce an empty column when ``half`` is 0
        # (dim == 1) and the result would be [N x 0] instead of [N x 1].
        padding = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding
181 | | - |
def encode_first_stage(self, masked_img):
    """Encode ``masked_img`` with ``self.vae`` and return the scaled latents.

    The encoder output is passed through ``retrieve_latents`` and the result
    is multiplied by ``self.vae.config.scaling_factor``.
    """
    vae = self.vae
    encoder_output = vae.encode(masked_img)
    latents = retrieve_latents(encoder_output)
    return latents * vae.config.scaling_factor
184 | 151 |
|
|
0 commit comments