Skip to content

Commit 9657980

Browse files
committed
[FIX] Ensure embeddings use correct device in AnyTextControlNetModel
1 parent b8ca0d6 commit 9657980

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/research_projects/anytext/anytext_controlnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def __init__(
9595
# self.fuse_block = self.fuse_block.to(dtype=torch.float16)
9696

9797
def forward(self, glyphs, positions, text_info):
98-
glyph_embedding = self.glyph_block(glyphs)
99-
position_embedding = self.position_block(positions)
98+
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
99+
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
100100
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
101101

102102
return guided_hint
@@ -390,7 +390,7 @@ def forward(
390390
# 2. pre-process
391391
sample = self.conv_in(sample)
392392

393-
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
393+
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
394394
sample = sample + controlnet_cond
395395

396396
# 3. down

0 commit comments

Comments
 (0)