Skip to content

Commit f60a72b

Browse files
committed
Fix device issues
1 parent 8b43bc3 commit f60a72b

File tree

5 files changed

+12
-7
lines changed

5 files changed

+12
-7
lines changed

examples/research_projects/anytext/auxiliary_latent_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, glyph_channels=1, position_channels=1, model_channels=320, **
         super().__init__()
         self.font = ImageFont.truetype("Arial_Unicode.ttf", 60)
         self.use_fp16 = kwargs.get("use_fp16", False)
-        self.device = kwargs.get("device", "cpu")
+        self.device = kwargs.get("device", "cuda")

         self.glyph_block = nn.Sequential(
             nn.Conv2d(glyph_channels, 8, 3, padding=1),

examples/research_projects/anytext/embedding_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
         self.position_encoder = EncodeNet(position_channels, token_dim)
         if emb_type == "ocr":
             self.proj = nn.Sequential(zero_module(nn.Linear(40 * 64, token_dim)), nn.LayerNorm(token_dim))
+            self.proj = self.proj.to(dtype=torch.float16 if kwargs.get("use_fp16", False) else torch.float32)
         if emb_type == "conv":
             self.glyph_encoder = EncodeNet(glyph_channels, token_dim)

examples/research_projects/anytext/frozen_clip_embedder_t3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
-        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
+        self.transformer = CLIPTextModel.from_pretrained(version, torch_dtype=torch.float16).to(device)
         if use_vision:
             self.vit = CLIPVisionModelWithProjection.from_pretrained(version)
             self.processor = AutoProcessor.from_pretrained(version)

examples/research_projects/anytext/pipeline_anytext.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 np_hint=np_hint,
             )
+            height, width = 512, 512
             # elif isinstance(controlnet, MultiControlNetModel):
             #     images = []

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(self, use_fp16):
             "position_channels": 1,
             "add_pos": False,
             "placeholder_string": "*",
+            "use_fp16": self.use_fp16,
         }
         self.embedding_manager = EmbeddingManager(self.frozen_CLIP_embedder_t3, **self.embedding_manager_config)
         # TODO: Understand the reason of param.requires_grad = True
@@ -41,8 +42,10 @@ def __init__(self, use_fp16):
         args["rec_image_shape"] = "3, 48, 320"
         args["rec_batch_num"] = 6
         args["rec_char_dict_path"] = "ppocr_keys_v1.txt"
-        args["use_fp16"] = False
-        self.cn_recognizer = TextRecognizer(args, self.text_predictor)
+        args["use_fp16"] = True
+        self.cn_recognizer = TextRecognizer(
+            args, self.text_predictor.to(dtype=torch.float16 if use_fp16 else torch.float32)
+        )
         for param in self.text_predictor.parameters():
             param.requires_grad = False
         self.embedding_manager.recog = self.cn_recognizer
@@ -149,9 +152,9 @@ def forward(
                 glyphs = np.zeros((h * gly_scale, w * gly_scale, 1))
                 gly_line = np.zeros((80, 512, 1))
                 pos = pre_pos[i]
-            text_info["glyphs"] += [self.arr2tensor(glyphs, len(prompt))]
-            text_info["gly_line"] += [self.arr2tensor(gly_line, len(prompt))]
-            text_info["positions"] += [self.arr2tensor(pos, len(prompt))]
+            text_info["glyphs"] += [self.arr2tensor(glyphs, num_images_per_prompt)]
+            text_info["gly_line"] += [self.arr2tensor(gly_line, num_images_per_prompt)]
+            text_info["positions"] += [self.arr2tensor(pos, num_images_per_prompt)]

             # hint = self.arr2tensor(np_hint, len(prompt))
0 commit comments

Comments
 (0)