Commit 936c2ff

Simplify

1 parent bbfe8f2

3 files changed: 3 additions & 41 deletions


examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 2 additions & 35 deletions
@@ -1,14 +1,8 @@
-# text -> glyph render -> glyph l_g -> glyph block ->
-# +> fuse layer
-# position l_p -> position block ->
-
-import math
 from typing import Optional

 import cv2
 import numpy as np
 import torch
-from einops import repeat
 from PIL import ImageFont
 from torch import nn

@@ -146,39 +140,12 @@ def forward(

         glyphs = torch.cat(text_info["glyphs"], dim=1).sum(dim=1, keepdim=True)
         positions = torch.cat(text_info["positions"], dim=1).sum(dim=1, keepdim=True)
-        t_emb = self.timestep_embedding(torch.tensor([1000], device="cuda"), self.model_channels, repeat_only=False)
-        if self.use_fp16:
-            t_emb = t_emb.half()
-        emb = self.time_embed(t_emb)
-        print(glyphs.shape, emb.shape, positions.shape, context.shape)
-        enc_glyph = self.glyph_block(glyphs.cuda(), emb, context)
-        enc_pos = self.position_block(positions.cuda(), emb, context)
+        enc_glyph = self.glyph_block(glyphs.cuda())
+        enc_pos = self.position_block(positions.cuda())
         guided_hint = self.fuse_block(torch.cat([enc_glyph, enc_pos, text_info["masked_x"].cuda()], dim=1))

         return guided_hint

-    def timestep_embedding(self, timesteps, dim, max_period=10000, repeat_only=False):
-        """
-        Create sinusoidal timestep embeddings.
-        :param timesteps: a 1-D Tensor of N indices, one per batch element.
-            These may be fractional.
-        :param dim: the dimension of the output.
-        :param max_period: controls the minimum frequency of the embeddings.
-        :return: an [N x dim] Tensor of positional embeddings.
-        """
-        if not repeat_only:
-            half = dim // 2
-            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-                device=timesteps.device
-            )
-            args = timesteps[:, None].float() * freqs[None]
-            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-            if dim % 2:
-                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-        else:
-            embedding = repeat(timesteps, "b -> b d", d=dim)
-        return embedding
-
     def encode_first_stage(self, masked_img):
         return retrieve_latents(self.vae.encode(masked_img)) * self.vae.config.scaling_factor

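For reference, the timestep_embedding helper deleted above implemented the standard sinusoidal embedding scheme. A minimal standalone sketch of the same computation (rewritten as a free function, dropping the repeat_only branch that the einops import served):

import math

import torch


def sinusoidal_timestep_embedding(timesteps, dim, max_period=10000):
    # Sinusoidal embeddings for a 1-D tensor of N timesteps -> [N, dim] tensor.
    # Frequencies decay geometrically from 1 down to 1/max_period over dim // 2 channels.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(timesteps.device)
    args = timesteps[:, None].float() * freqs[None]  # [N, half]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [N, 2 * half]
    if dim % 2:  # pad odd dims with a zero channel
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


# e.g. sinusoidal_timestep_embedding(torch.tensor([0, 500, 999]), dim=320).shape -> torch.Size([3, 320])

In the pre-change code this was only ever called with a fixed t=1000, so once glyph_block and position_block stopped taking emb and context, the whole embedding path could be removed.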
examples/research_projects/anytext/text_controlnet.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@

 class AnyTextControlNetModel(ControlNetModel):
     """
-    A PromptDiffusionControlNet model.
+    A AnyTextControlNetModel model.

     Args:
         in_channels (`int`, defaults to 4):

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 0 additions & 5 deletions
@@ -1,8 +1,3 @@
-# text -> glyph render -> glyph lines -> OCR -> linear ->
-# +> Token Replacement -> FrozenCLIPEmbedderT3
-# text -> tokenizer ->
-
-
 import cv2
 import numpy as np
 import torch
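The deleted header comments in text_embedding_module.py summarized the text-embedding path: rendered glyph lines go through OCR and a linear projection, and the resulting vectors replace placeholder token embeddings before FrozenCLIPEmbedderT3 encodes the prompt. A sketch of that token-replacement step under stated assumptions; the function name, placeholder_id, and tensor shapes here are hypothetical, not the module's actual API:

import torch


def replace_placeholder_embeddings(token_ids, token_embeds, glyph_embeds, placeholder_id):
    # token_ids: [seq_len] tokenized prompt; token_embeds: [seq_len, dim] CLIP input
    # embeddings; glyph_embeds: [n_texts, dim] OCR features after the linear layer.
    # Swap each placeholder token's embedding for the matching glyph embedding,
    # in order of appearance, leaving every other token untouched.
    out = token_embeds.clone()
    positions = (token_ids == placeholder_id).nonzero(as_tuple=True)[0]
    assert len(positions) == len(glyph_embeds), "one glyph embedding per placeholder"
    out[positions] = glyph_embeds.to(out.dtype)
    return out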
