1717
1818from absl import app
1919from ai_edge_torch .generative .examples .hammer import hammer
20- from ai_edge_torch .generative .layers import kv_cache
2120from ai_edge_torch .generative .utilities import converter
22- from ai_edge_torch .generative .utilities import export_config as export_cfg
23- import torch
24-
21+ from ai_edge_torch .generative .utilities import export_config
2522
2623flags = converter .define_conversion_flags ('hammer' )
27- ExportConfig = export_cfg .ExportConfig
28-
2924
3025_MODEL_SIZE = flags .DEFINE_enum (
3126 'model_size' ,
4035}
4136
4237
def _create_mask(mask_len: int, kv_cache_max_len: int) -> torch.Tensor:
    """Builds an additive causal attention mask.

    Args:
        mask_len: Number of rows in the mask.
        kv_cache_max_len: Number of columns in the mask.

    Returns:
        A float32 tensor of shape ``(1, 1, mask_len, kv_cache_max_len)`` that
        holds ``-inf`` strictly above the main diagonal and ``0`` everywhere
        else.
    """
    mask = torch.full(
        (mask_len, kv_cache_max_len), float('-inf'), dtype=torch.float32
    )
    # triu(diagonal=1) keeps only the entries strictly above the diagonal and
    # zeroes the rest; the two unsqueeze calls prepend two singleton dims.
    mask = torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
    return mask
49-
50-
def _create_export_config(
    prefill_seq_lens: list[int], kv_cache_max_len: int
) -> ExportConfig:
    """Creates the export config for the model.

    Args:
        prefill_seq_lens: Prefill sequence lengths. A list yields one prefill
            mask per entry; any non-list value is passed to ``_create_mask``
            as a single length and yields a single mask.
        kv_cache_max_len: Column count used for every mask.

    Returns:
        An ``ExportConfig`` with ``prefill_mask``, ``decode_mask`` and
        ``kvcache_layout`` (set to ``kv_cache.KV_LAYOUT_TRANSPOSED``) filled
        in.
    """
    # Named `config` rather than `export_config` so the local does not shadow
    # a module imported under that name.
    config = ExportConfig()

    if isinstance(prefill_seq_lens, list):
        config.prefill_mask = [
            _create_mask(seq_len, kv_cache_max_len)
            for seq_len in prefill_seq_lens
        ]
    else:
        config.prefill_mask = _create_mask(prefill_seq_lens, kv_cache_max_len)

    # The decode mask is the one-row case of the same causal mask, so reuse
    # the helper instead of rebuilding it by hand.
    config.decode_mask = _create_mask(1, kv_cache_max_len)
    config.kvcache_layout = kv_cache.KV_LAYOUT_TRANSPOSED
    return config
70-
71-
7238def main (_ ):
7339 pytorch_model = _BUILDER [_MODEL_SIZE .value ](
7440 flags .FLAGS .checkpoint_path , kv_cache_max_len = flags .FLAGS .kv_cache_max_len
@@ -80,11 +46,7 @@ def main(_):
8046 prefill_seq_len = flags .FLAGS .prefill_seq_lens ,
8147 quantize = flags .FLAGS .quantize ,
8248 lora_ranks = flags .FLAGS .lora_ranks ,
83- export_config = _create_export_config (
84- flags .FLAGS .prefill_seq_lens , flags .FLAGS .kv_cache_max_len
85- )
86- if flags .FLAGS .transpose_kv_cache
87- else ExportConfig (),
49+ export_config = export_config .get_from_flags (),
8850 )
8951
9052
0 commit comments