@@ -57,7 +57,7 @@ def define_conversion_flags(model_name: str):
5757 )
5858 flags .DEFINE_string (
5959 'output_name_prefix' ,
60- f' { model_name } ' ,
60+ model_name ,
6161 'The prefix of the output tflite model name.' ,
6262 )
6363 flags .DEFINE_multi_integer (
@@ -91,6 +91,7 @@ def convert_to_tflite(
9191 output_name_prefix : str ,
9292 prefill_seq_len : Union [int , list [int ]],
9393 pixel_values_size : torch .Size = None ,
94+ pixel_seq_len : int = 0 ,
9495 quantize : bool = True ,
9596 config : cfg .ModelConfig = None ,
9697 lora_ranks : Optional [list [int ]] = None ,
@@ -133,12 +134,18 @@ def convert_to_tflite(
133134 use. If a list, the model will have multiple prefill signatures.
134135 pixel_values_size (torch.Size, optional): The size of pixel values to pass
135136 to the model. If None, the model is not expected to take pixel values.
137+ pixel_seq_len (int, optional): The length of pixel tokens, or pixel
138+ embeddings generated by the image encoder with pixel values. The actual
139+ length of prefill_seq_len will be added by pixel_seq_len when pixel
140+ values are passed.
136141 quantize (bool, optional): Whether the model should be quanized. Defaults
137142 to True.
138143 config (cfg.ModelConfig, optional): The model config used to configure KV
139144 cache. If None, it uses the config of the pytorch_model.
140145 lora_ranks (list[int], optional): The ranks of the LORA layers. If None,
141146 no LoRA signatures will be added.
147+ export_config (ExportConfig, optional): The export configuration. If None,
148+ it uses the default export configuration.
142149 """
143150 # pylint: disable=protected-access
144151 torch ._dynamo .config .cache_size_limit = 64
@@ -173,6 +180,7 @@ def convert_to_tflite(
173180 output_file ,
174181 prefill_seq_lens ,
175182 pixel_values_size ,
183+ pixel_seq_len ,
176184 quantize ,
177185 config ,
178186 loras ,
@@ -185,6 +193,7 @@ def _export_helper(
185193 output_file : str ,
186194 prefill_seq_lens : list [int ],
187195 pixel_values_size : torch .Size ,
196+ pixel_seq_len : int ,
188197 quantize : bool ,
189198 config : cfg .ModelConfig ,
190199 loras : list [None | lora_utils .LoRA ],
@@ -197,11 +206,18 @@ def _export_helper(
197206 prefill_tokens_list .append (torch .full ((1 , seq_len ), 0 , dtype = torch .int ))
198207 prefill_input_pos_list .append (torch .arange (0 , seq_len , dtype = torch .int ))
199208
200- prefill_pixel_values = (
201- torch .full (pixel_values_size , 0 , dtype = torch .float32 )
202- if pixel_values_size
203- else None
204- )
209+ prefill_pixel_values = None
210+ prefill_tokens_list_with_pixel = []
211+ prefill_input_pos_list_with_pixel = []
212+ if pixel_values_size is not None :
213+ prefill_pixel_values = torch .full (pixel_values_size , 0 , dtype = torch .float32 )
214+ for seq_len in prefill_seq_lens :
215+ prefill_tokens_list_with_pixel .append (
216+ torch .full ((1 , seq_len + pixel_seq_len ), 0 , dtype = torch .int )
217+ )
218+ prefill_input_pos_list_with_pixel .append (
219+ torch .arange (0 , seq_len + pixel_seq_len , dtype = torch .int )
220+ )
205221
206222 if export_config .prefill_mask is None :
207223 prefill_masks = None
@@ -238,13 +254,11 @@ def _export_helper(
238254 for lora in loras :
239255 for i in range (len (prefill_seq_lens )):
240256 prefill_seq_len = prefill_seq_lens [i ]
241- prefill_tokens = prefill_tokens_list [i ]
242- prefill_input_pos = prefill_input_pos_list [i ]
243257 prefill_signature_name = f'prefill_{ prefill_seq_len } '
244258
245259 sample_kwargs = {
246- 'tokens' : prefill_tokens ,
247- 'input_pos' : prefill_input_pos ,
260+ 'tokens' : prefill_tokens_list [ i ] ,
261+ 'input_pos' : prefill_input_pos_list [ i ] ,
248262 'kv_cache' : prefill_kv ,
249263 }
250264 if prefill_masks is not None :
@@ -261,13 +275,13 @@ def _export_helper(
261275 )
262276
263277 if prefill_pixel_values is not None :
278+ sample_kwargs ['tokens' ] = prefill_tokens_list_with_pixel [i ]
279+ sample_kwargs ['input_pos' ] = prefill_input_pos_list_with_pixel [i ]
280+ sample_kwargs ['pixel_values' ] = prefill_pixel_values
264281 converter .add_signature (
265282 prefill_signature_name + '_pixel' ,
266283 mod ,
267- sample_kwargs = {
268- ** sample_kwargs ,
269- 'pixel_values' : prefill_pixel_values ,
270- },
284+ sample_kwargs = sample_kwargs ,
271285 )
272286
273287 sample_kwargs = {
0 commit comments