 from typing import Optional
 
 import torch
+from executorch.examples.models.llama.config.llm_config import LlmConfig
 
 from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export
 from executorch.examples.models.llama.runner.eager import execute_runner
@@ -22,18 +23,23 @@ class EagerLlamaRunner(TorchTuneLlamaRunner):
     Runs llama in eager mode with provided checkpoint file.
     """
 
-    def __init__(self, args):
-        with open(args.params, "r") as f:
+    def __init__(
+        self,
+        llm_config: LlmConfig,
+        tokenizer_config_path: Optional[str] = None,
+        use_attention_sink: bool = False,
+    ):
+        with open(llm_config.base.params, "r") as f:
             params = json.loads(f.read())
         super().__init__(
-            tokenizer_path=args.tokenizer_path,
-            max_seq_len=args.max_seq_length,
+            tokenizer_path=llm_config.base.tokenizer_path,
+            max_seq_len=llm_config.export.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=args.use_kv_cache,
+            use_kv_cache=llm_config.model.use_kv_cache,
             vocab_size=params["vocab_size"],
             device="cuda" if torch.cuda.is_available() else "cpu",
         )
-        manager: LLMEdgeManager = _prepare_for_llama_export(args)
+        manager: LLMEdgeManager = _prepare_for_llama_export(llm_config)
         self.model = manager.model.eval().to(device=self.device)
 
     def forward(
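For context, a minimal usage sketch of the new `LlmConfig`-based constructor. This is not part of the PR: the import path for `EagerLlamaRunner`, the default-constructibility of `LlmConfig`, and the placeholder paths/values are assumptions, while the field names (`base.params`, `base.tokenizer_path`, `export.max_seq_length`, `model.use_kv_cache`) come from the diff above.

```python
# Hypothetical usage sketch of the refactored constructor (not from this PR).
# Assumes LlmConfig is a default-constructible config object whose nested
# sections (base, export, model) can be populated directly.
from executorch.examples.models.llama.config.llm_config import LlmConfig

# Placeholder import path for the class defined in the diff above.
from executorch.examples.models.llama3_2_vision.runner.eager import EagerLlamaRunner

llm_config = LlmConfig()
llm_config.base.params = "params.json"              # placeholder: model hyperparameters (vocab_size, ...)
llm_config.base.tokenizer_path = "tokenizer.model"  # placeholder: tokenizer file
llm_config.export.max_seq_length = 128
llm_config.model.use_kv_cache = True

runner = EagerLlamaRunner(llm_config)
```

The same `LlmConfig` instance is forwarded to `_prepare_for_llama_export`, so the runner and the export path read model settings from a single config object instead of a parsed argparse namespace.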