3 files changed: +12 -3 lines changed
Directory: examples/models/llama/runner (3 files changed, +12 -3 lines changed)
Columns: original file line number | diff line number | diff line change
@@ -34,7 +34,7 @@ def __init__(self, args):
3434 max_batch_size = 1 ,
3535 use_kv_cache = args .use_kv_cache ,
3636 vocab_size = params ["vocab_size" ],
37- has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS
37+ has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS ,
3838 device = "cuda" if torch .cuda .is_available () else "cpu" ,
3939 )
4040 manager : LLMEdgeManager = _prepare_for_llama_export (args )
Original file line number Diff line number Diff line change @@ -73,7 +73,6 @@ def __init__(
7373 has_full_logits: whether the model returns the full logits or only returns the last logit.
7474 device: device to run the runner on.
7575 """
76- self .model_name = model
7776 self .max_seq_len = max_seq_len
7877 self .max_batch_size = max_batch_size
7978 self .use_kv_cache = use_kv_cache
Original file line number Diff line number Diff line change 1010
1111import torch
1212
13+ from executorch .examples .models .llama .export_llama_lib import EXECUTORCH_DEFINED_MODELS , TORCHTUNE_DEFINED_MODELS
14+
1315from executorch .extension .pybindings .portable_lib import _load_for_executorch
1416
1517# Load custom ops and quantized ops.
1618from executorch .extension .pybindings import portable_lib # noqa # usort: skip
1719
18- from executorch .examples .models .llama2 .runner .generation import LlamaRunner
20+ from executorch .examples .models .llama .runner .generation import LlamaRunner
1921
2022# Note: import this after portable_lib
2123# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
@@ -36,6 +38,7 @@ def __init__(self, args):
3638 max_batch_size = 1 ,
3739 use_kv_cache = args .kv_cache ,
3840 vocab_size = params ["vocab_size" ],
41+ has_full_logits = args .model in TORCHTUNE_DEFINED_MODELS ,
3942 )
4043 self .model = _load_for_executorch (args .pte )
4144
@@ -58,8 +61,15 @@ def forward(
5861
5962
6063def build_args_parser () -> argparse .ArgumentParser :
64+ # TODO: merge these with build_args_parser from export_llama_lib.
6165 parser = argparse .ArgumentParser ()
6266
67+ parser .add_argument (
68+ "--model" ,
69+ default = "llama" ,
70+ choices = EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS ,
71+ )
72+
6373 parser .add_argument (
6474 "-f" ,
6575 "--pte" ,
You can’t perform that action at this time.
0 commit comments