@@ -18,13 +18,22 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
18
18
self .llama_template_images = "<|im_start|>system\n Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n <|im_start|>user\n <|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n <|im_start|>assistant\n "
19
19
20
20
def tokenize_with_weights (self , text , return_word_ids = False , llama_template = None , images = [], ** kwargs ):
21
- if llama_template is None :
22
- if len (images ) > 0 :
23
- llama_text = self .llama_template_images .format (text )
24
- else :
25
- llama_text = self .llama_template .format (text )
21
+ skip_template = False
22
+ if text .startswith ('<|im_start|>' ):
23
+ skip_template = True
24
+ if text .startswith ('<|start_header_id|>' ):
25
+ skip_template = True
26
+
27
+ if skip_template :
28
+ llama_text = text
26
29
else :
27
- llama_text = llama_template .format (text )
30
+ if llama_template is None :
31
+ if len (images ) > 0 :
32
+ llama_text = self .llama_template_images .format (text )
33
+ else :
34
+ llama_text = self .llama_template .format (text )
35
+ else :
36
+ llama_text = llama_template .format (text )
28
37
tokens = super ().tokenize_with_weights (llama_text , return_word_ids = return_word_ids , disable_weights = True , ** kwargs )
29
38
key_name = next (iter (tokens ))
30
39
embed_count = 0
@@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
47
56
def __init__(self, device="cpu", dtype=None, model_options=None):
    """Construct the Qwen2.5-7B VL text-encoder wrapper.

    Forwards to the sd1_clip.SD1ClipModel base constructor with the
    "qwen25_7b" model name and Qwen25_7BVLIModel as the clip
    implementation; assumes the base class accepts these keyword
    arguments (established by the visible super().__init__ call).

    Args:
        device: Torch device string the model is placed on.
        dtype: Optional torch dtype for the model weights.
        model_options: Optional dict of extra model options; defaults
            to an empty dict. (Was a mutable default ``{}`` — replaced
            with ``None`` so a single shared dict is never reused
            across instances.)
    """
    if model_options is None:
        model_options = {}
    super().__init__(device=device, dtype=dtype, name="qwen25_7b",
                     clip_model=Qwen25_7BVLIModel, model_options=model_options)
49
58
50
- def encode_token_weights (self , token_weight_pairs ):
59
+ def encode_token_weights (self , token_weight_pairs , template_end = - 1 ):
51
60
out , pooled , extra = super ().encode_token_weights (token_weight_pairs )
52
61
tok_pairs = token_weight_pairs ["qwen25_7b" ][0 ]
53
62
count_im_start = 0
54
- for i , v in enumerate (tok_pairs ):
55
- elem = v [0 ]
56
- if not torch .is_tensor (elem ):
57
- if isinstance (elem , numbers .Integral ):
58
- if elem == 151644 and count_im_start < 2 :
59
- template_end = i
60
- count_im_start += 1
61
-
62
- if out .shape [1 ] > (template_end + 3 ):
63
- if tok_pairs [template_end + 1 ][0 ] == 872 :
64
- if tok_pairs [template_end + 2 ][0 ] == 198 :
65
- template_end += 3
63
+ if template_end == - 1 :
64
+ for i , v in enumerate (tok_pairs ):
65
+ elem = v [0 ]
66
+ if not torch .is_tensor (elem ):
67
+ if isinstance (elem , numbers .Integral ):
68
+ if elem == 151644 and count_im_start < 2 :
69
+ template_end = i
70
+ count_im_start += 1
71
+
72
+ if out .shape [1 ] > (template_end + 3 ):
73
+ if tok_pairs [template_end + 1 ][0 ] == 872 :
74
+ if tok_pairs [template_end + 2 ][0 ] == 198 :
75
+ template_end += 3
66
76
67
77
out = out [:, template_end :]
68
78
0 commit comments