# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, findall


class MolmoTemplate(Template):
    """Template for Molmo that delegates image preprocessing to the model's processor.

    The processor-generated prefix tokens are prepended to the encoded text in
    ``_encode``; media tags themselves contribute nothing to the text.
    """

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    inputs: StdTemplateInputs) -> List[Context]:
        """Return an empty context list for every media tag.

        No tokens are emitted at the tag position; ``_encode`` prepends the
        prefix produced by the processor instead.
        """
        return []

    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        """Encode the text, prepend the processor's prefix tokens and attach its other outputs.

        Returns:
            The dict produced by the parent ``_encode``, with ``input_ids`` (and
            ``labels`` when non-empty) extended by the prefix and updated with
            every remaining key returned by ``self.processor.process``.

        Raises:
            AssertionError: if ``' User'`` does not tokenize to exactly one token,
                or that token does not occur exactly once in the processor output.
        """
        encoded = super()._encode(inputs)
        # Run the processor with empty text; only the tokens that precede the
        # ' User' marker in its output are kept (see the slice below).
        images_inputs = self.processor.process(images=inputs.images or None, text='')
        images_input_ids = images_inputs.pop('input_ids').tolist()
        user_token = self._tokenize(' User')
        assert len(user_token) == 1, f"expected ' User' to be a single token, got {user_token}"
        idx = findall(images_input_ids, user_token[0])
        assert len(idx) == 1, f"expected exactly one ' User' token in the processor output, got {len(idx)}"
        prefix_len = idx[0]
        labels = encoded['labels']
        encoded['input_ids'] = images_input_ids[:prefix_len] + encoded['input_ids']
        if labels:
            # Keep labels aligned with input_ids: prefix positions are filled with -100.
            encoded['labels'] = [-100] * prefix_len + labels
        if 'images' in images_inputs:
            images_inputs['images'] = images_inputs['images'].to(self.config.torch_dtype)
        encoded.update(images_inputs)
        return encoded

    def generate(self, model, **kwargs):
        """Call ``model.generate_from_batch`` with the batch keys split out of ``kwargs``.

        ``generation_config`` is required in ``kwargs``; any keyword that is not
        part of the batch is forwarded unchanged.
        """
        # attention_mask is discarded up front, so its entry in ``batch`` is always None.
        kwargs.pop('attention_mask', None)
        generation_config = kwargs.pop('generation_config')
        batch = {
            k: kwargs.pop(k, None)
            for k in ['input_ids', 'attention_mask', 'images', 'image_input_idx', 'image_masks']
        }
        return model.generate_from_batch(batch, generation_config, **kwargs)

    def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate a batch with the parent collator and add the stacked image inputs."""
        res = super()._data_collator(batch, padding_to=padding_to)
        # prepare batchfy inputs
        keys = ['images', 'image_input_idx', 'image_masks']
        # NOTE(review): presumably fetch_inputs gathers the per-sample values of each key
        # into a list - confirm against the base class.
        images_res = self.fetch_inputs(batch, keys)
        for key in keys:
            val = images_res.get(key)
            if val:
                # Missing or empty entries are left as returned by fetch_inputs.
                images_res[key] = torch.stack(val)
        res.update(images_res)
        return res
16757
16858
@@ -171,8 +61,8 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
17161 MLLMTemplateType .molmo ,
17262 prefix = [],
17363 prompt = [' User: {{QUERY}} Assistant:' ],
174- chat_sep = [ '<|endoftext|>' ] ,
64+ chat_sep = None ,
17565 suffix = ['<|endoftext|>' ],
17666 template_cls = MolmoTemplate ,
177- placeholder_tokens = ['<|image| >' ],
67+ placeholder_tokens = ['<im_patch >' ],
17868 ))
0 commit comments