@@ -11,9 +11,12 @@
 import torch
 
 from examples.models.llama2.llama_transformer import ModelArgs
-from executorch.examples.models.model_factory import EagerModelFactory
-
-from .generation import LlamaRunner
+from executorch.examples.models.llama2.export_llama_lib import (
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
+)
+from executorch.examples.models.llama2.runner.generation import LlamaRunner
+from executorch.extension.llm.export import LLMEdgeManager
 
 
 class EagerLlamaRunner(LlamaRunner):
@@ -25,21 +28,17 @@ def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=args.max_len,
+            max_seq_len=args.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=True,
+            use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
-        self.model, _, _, _ = EagerModelFactory.create_model(
-            "llama2",
-            "Llama2Model",
-            checkpoint=args.checkpoint,
-            params=args.params,
-            use_kv_cache=True,
-            fairseq2=False,
-            max_seq_len=args.max_len,
-            enable_dynamic_shape=True,
+        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = (
+            manager.model.eval().to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.model.eval().to(device="cpu")
         )
 
     def forward(
@@ -51,34 +50,7 @@ def forward(
 
 
 def build_args_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        default=None,
-        help="path to model checkpoint file",
-    )
-
-    parser.add_argument(
-        "--params",
-        type=str,
-        default=None,
-        help="model params file",
-    )
-
-    parser.add_argument(
-        "--max_len",
-        type=int,
-        default=128,
-        help="Maximum length of the generated response sequence.",
-    )
-
-    parser.add_argument(
-        "--tokenizer",
-        type=str,
-        default=None,
-    )
+    parser = _build_args_parser()
 
     parser.add_argument(
         "--prompt",
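With this change the eager runner reuses the export pipeline's CLI instead of defining its own flags. A minimal usage sketch follows; the module path and the flag spellings (--checkpoint, --params, --tokenizer_path) are assumptions inferred from the attributes this diff reads (args.params, args.tokenizer_path, args.max_seq_length, args.use_kv_cache) and from the flags the old parser defined, not confirmed API.

# Sketch: driving EagerLlamaRunner through the merged parser.
# Import path and flag names are assumptions (see note above).
from executorch.examples.models.llama2.runner.eager import (
    EagerLlamaRunner,
    build_args_parser,
)

args = build_args_parser().parse_args(
    [
        "--checkpoint", "llama2.pt",            # hypothetical checkpoint path
        "--params", "params.json",              # hypothetical params file
        "--tokenizer_path", "tokenizer.model",  # hypothetical tokenizer path
        "--max_seq_length", "128",
        "--prompt", "Hello, world",
    ]
)
# Loads the eager model via _prepare_for_llama_export and picks CUDA when available.
runner = EagerLlamaRunner(args)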