Skip to content

Commit da115bd

Browse files
committed
ensure backwards compat with optional args
1 parent 752f7a1 commit da115bd

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

comfy/sd.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,16 @@ def add_patches(self, patches, strength=1.0):
372372
def clip_layer(self, layer_idx):
373373
self.layer_idx = layer_idx
374374

375-
def tokenize(self, text):
376-
return self.tokenizer.tokenize_with_weights(text)
375+
def tokenize(self, text, return_word_ids=False):
376+
return self.tokenizer.tokenize_with_weights(text, return_word_ids)
377377

378-
def encode(self, tokens):
378+
def encode(self, text, from_tokens=False):
379379
if self.layer_idx is not None:
380380
self.cond_stage_model.clip_layer(self.layer_idx)
381+
if from_tokens:
382+
tokens = text
383+
else:
384+
tokens = self.tokenizer.tokenize_with_weights(text)
381385
try:
382386
self.patcher.patch_model()
383387
cond = self.cond_stage_model.encode_token_weights(tokens)

comfy/sd1_clip.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _try_get_embedding(self, embedding_name:str):
240240
return (embed, "")
241241

242242

243-
def tokenize_with_weights(self, text:str):
243+
def tokenize_with_weights(self, text:str, return_word_ids=False):
244244
'''
245245
Takes a prompt and converts it to a list of (token, weight, word id) elements.
246246
Tokens can both be integer tokens and pre computed CLIP tensors.
@@ -301,6 +301,10 @@ def tokenize_with_weights(self, text:str):
301301

302302
#add start and end tokens
303303
batched_tokens = [[(self.start_token, 1.0, 0)] + x + [(self.end_token, 1.0, 0)] for x in batched_tokens]
304+
305+
if not return_word_ids:
306+
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
307+
304308
return batched_tokens
305309

306310

nodes.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ def INPUT_TYPES(s):
4444
CATEGORY = "conditioning"
4545

4646
def encode(self, clip, text):
47-
tokens = clip.tokenize(text)
48-
return ([[clip.encode(tokens), {}]], )
47+
return ([[clip.encode(text), {}]], )
4948

5049
class ConditioningCombine:
5150
@classmethod

0 commit comments

Comments
 (0)