@@ -95,6 +95,18 @@ def define_conversion_flags(model_name: str):
   return flags


+def _build_mask(mask_len, kv_cache_max_len, causal_mask_value) -> torch.Tensor:
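+  """Builds a [1, 1, mask_len, kv_cache_max_len] causal mask.
+
+  Positions a query may attend to are 0; future positions hold
+  causal_mask_value. If mask_len is a list, one mask is built per entry.
+  """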
+  if isinstance(mask_len, list):
+    return [
+        _build_mask(i, kv_cache_max_len, causal_mask_value) for i in mask_len
+    ]
+
+  mask = torch.full(
+      (mask_len, kv_cache_max_len), causal_mask_value, dtype=torch.float32
+  )
+  return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
+
+
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     output_path: str,
@@ -229,14 +241,15 @@ def _export_helper(
         torch.arange(0, seq_len + pixel_seq_len, dtype=torch.int)
     )

-  if export_config.prefill_mask is None:
-    prefill_masks = None
-  elif isinstance(export_config.prefill_mask, torch.Tensor):
-    prefill_masks = [export_config.prefill_mask]
-  elif isinstance(export_config.prefill_mask, list):
-    prefill_masks = export_config.prefill_mask
-  else:
-    raise ValueError('Prefill masks unrecognized.')
+  prefill_masks = None
+  if flags.FLAGS.mask_as_input:
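+    # _build_mask returns one [1, 1, seq_len, kv_cache_max_len] causal mask
+    # per entry in prefill_seq_lens.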
+    prefill_masks = [
+        _build_mask(
+            flags.FLAGS.prefill_seq_lens,
+            flags.FLAGS.kv_cache_max_len,
+            config.get_causal_mask_value(),
+        )
+    ]

   if prefill_masks:
     assert len(prefill_masks) == len(prefill_seq_lens)
@@ -299,8 +312,17 @@ def _export_helper(
       'input_pos': decode_input_pos,
       'kv_cache': decode_kv,
   }
-  if export_config.decode_mask is not None:
-    sample_kwargs['mask'] = export_config.decode_mask
+  if flags.FLAGS.mask_as_input:
+    # Note that the decode mask is not a correct causal mask, but that is fine
+    # for conversion, where only the mask's shape matters.
+    # A correct decode causal mask for a given token position would be built
+    # like:
+    #
+    #   torch.triu(mask, diagonal=decode_position).unsqueeze(0).unsqueeze(0)
+    #
+    sample_kwargs['mask'] = _build_mask(
+        1, flags.FLAGS.kv_cache_max_len, config.get_causal_mask_value()
+    )
   if lora is not None:
     sample_kwargs['lora'] = lora

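For reference, here is a minimal sketch of what _build_mask produces. The
concrete values are illustrative only: kv_cache_max_len is normally set by the
--kv_cache_max_len flag and the mask value by config.get_causal_mask_value().

    import torch

    kv_cache_max_len = 4                # illustration only
    causal_mask_value = float('-inf')   # stand-in mask value

    # Prefill mask for a 3-token prompt: shape [1, 1, 3, 4], with future
    # positions masked out.
    mask = torch.full(
        (3, kv_cache_max_len), causal_mask_value, dtype=torch.float32
    )
    prefill_mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    # prefill_mask[0, 0]:
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf]])

    # Decode-shaped mask: [1, 1, 1, 4]. With diagonal=1 only position 0 is
    # unmasked, which is fine for conversion because only the shape is traced;
    # a runtime decode mask would use diagonal=decode_position as noted in the
    # comment above.
    decode_mask = torch.triu(
        torch.full((1, kv_cache_max_len), causal_mask_value), diagonal=1
    ).unsqueeze(0).unsqueeze(0)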