Skip to content

Commit 4267c84

Browse files
committed
refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule
1 parent 020074a commit 4267c84

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

examples/research_projects/anytext/frozen_clip_embedder_t3.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,8 @@ def split_chunks(self, input_ids, chunk_size=75):
207207
remaining_group_pad = torch.cat((id_start, remaining_group, padding, id_end), dim=1)
208208
tokens_list.append(remaining_group_pad)
209209
return tokens_list
210+
211+
def to(self, *args, **kwargs):
212+
self.transformer = self.transformer.to(*args, **kwargs)
213+
self.device = self.transformer.device
214+
return self

examples/research_projects/anytext/text_embedding_module.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,9 +297,9 @@ def insert_spaces(self, string, nSpace):
297297
new_string += char + " " * nSpace
298298
return new_string[:-nSpace]
299299

300-
def to(self, device):
301-
self.device = device
302-
self.frozen_CLIP_embedder_t3.to(device)
303-
self.embedding_manager.to(device)
304-
self.text_predictor.to(device)
300+
def to(self, *args, **kwargs):
301+
self.frozen_CLIP_embedder_t3 = self.frozen_CLIP_embedder_t3.to(*args, **kwargs)
302+
self.embedding_manager = self.embedding_manager.to(*args, **kwargs)
303+
self.text_predictor = self.text_predictor.to(*args, **kwargs)
304+
self.device = self.frozen_CLIP_embedder_t3.device
305305
return self

0 commit comments

Comments
 (0)