
Commit 1e098d6

Don't add template to qwen2.5vl when template is in prompt. (#10043)
Make the hunyuan image refiner template_end 36.
Parent commit: cd66d72

File tree (2 files changed, +36 -20 lines):
    comfy/text_encoders/hunyuan_image.py
    comfy/text_encoders/qwen_image.py

comfy/text_encoders/hunyuan_image.py (7 additions, 1 deletion)

@@ -63,7 +63,13 @@ def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
             self.byt5_small = None
 
     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        template_end = -1
+        if tok_pairs[0][0] == 27:
+            if len(tok_pairs) > 36:  # refiner prompt uses a fixed 36 template_end
+                template_end = 36
+
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
             out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
            extra["conditioning_byt5small"] = out[0]

comfy/text_encoders/qwen_image.py (29 additions, 19 deletions)

@@ -18,13 +18,22 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
         self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
 
     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
-        if llama_template is None:
-            if len(images) > 0:
-                llama_text = self.llama_template_images.format(text)
-            else:
-                llama_text = self.llama_template.format(text)
+        skip_template = False
+        if text.startswith('<|im_start|>'):
+            skip_template = True
+        if text.startswith('<|start_header_id|>'):
+            skip_template = True
+
+        if skip_template:
+            llama_text = text
         else:
-            llama_text = llama_template.format(text)
+            if llama_template is None:
+                if len(images) > 0:
+                    llama_text = self.llama_template_images.format(text)
+                else:
+                    llama_text = self.llama_template.format(text)
+            else:
+                llama_text = llama_template.format(text)
         tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
         key_name = next(iter(tokens))
         embed_count = 0
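
In plain terms, a prompt that already begins with "<|im_start|>" (ChatML/Qwen style) or "<|start_header_id|>" (Llama 3 style) is now tokenized verbatim instead of being wrapped in the default template again, which is what the commit title means by "when template is in prompt". A minimal sketch of that decision in isolation; the function name and argument layout are mine, only the two prefixes and the template fallback order come from the diff:

def resolve_llama_text(text, default_template, image_template, llama_template=None, images=()):
    # A prompt that already carries its own chat template is passed through untouched.
    if text.startswith("<|im_start|>") or text.startswith("<|start_header_id|>"):
        return text
    # Otherwise wrap it exactly as before: an explicit template wins, then the
    # image template when reference images are present, then the plain template.
    if llama_template is not None:
        return llama_template.format(text)
    return (image_template if len(images) > 0 else default_template).format(text)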
@@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
 
-    def encode_token_weights(self, token_weight_pairs):
+    def encode_token_weights(self, token_weight_pairs, template_end=-1):
         out, pooled, extra = super().encode_token_weights(token_weight_pairs)
         tok_pairs = token_weight_pairs["qwen25_7b"][0]
         count_im_start = 0
-        for i, v in enumerate(tok_pairs):
-            elem = v[0]
-            if not torch.is_tensor(elem):
-                if isinstance(elem, numbers.Integral):
-                    if elem == 151644 and count_im_start < 2:
-                        template_end = i
-                        count_im_start += 1
-
-        if out.shape[1] > (template_end + 3):
-            if tok_pairs[template_end + 1][0] == 872:
-                if tok_pairs[template_end + 2][0] == 198:
-                    template_end += 3
+        if template_end == -1:
+            for i, v in enumerate(tok_pairs):
+                elem = v[0]
+                if not torch.is_tensor(elem):
+                    if isinstance(elem, numbers.Integral):
+                        if elem == 151644 and count_im_start < 2:
+                            template_end = i
+                            count_im_start += 1
+
+            if out.shape[1] > (template_end + 3):
+                if tok_pairs[template_end + 1][0] == 872:
+                    if tok_pairs[template_end + 2][0] == 198:
+                        template_end += 3
 
         out = out[:, template_end:]
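
A worked example of the fallback scan above, simplified (no tensor or sequence-length guards) and using a hand-made token list. 151644 is Qwen's <|im_start|> id; 872 and 198 are the ids the diff checks right after it, presumably the "user" and newline tokens of the default template; every other id below is invented for illustration:

tok_pairs = [
    (151644, 1.0), (1001, 1.0), (198, 1.0),    # <|im_start|> system \n   (system turn of the template)
    (1002, 1.0), (151645, 1.0), (198, 1.0),    # system text, <|im_end|>, \n  (ids invented)
    (151644, 1.0), (872, 1.0), (198, 1.0),     # <|im_start|> user \n
    (1003, 1.0), (151645, 1.0),                # the actual prompt token(s), <|im_end|>
]

template_end = -1
count_im_start = 0
for i, (tok, _weight) in enumerate(tok_pairs):
    if isinstance(tok, int) and tok == 151644 and count_im_start < 2:
        template_end = i                       # ends up at index 6, the second <|im_start|>
        count_im_start += 1

if tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
    template_end += 3                          # also skip "user" and "\n", so template_end == 9

print(template_end)                            # 9: slicing out[:, 9:] keeps only the prompt tokens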
