
import torch

from executorch.examples.models.llama2.llama_transformer import ModelArgs
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load custom ops and quantized ops.
from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

# Note: import this after portable_lib
# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa

from executorch.examples.models.llama2.runner.generation import LlamaRunner

class NativeLlamaRunner(LlamaRunner):
    def __init__(self, args):
        # ... (earlier __init__ lines not visible in this view) ...
            max_seq_len=args.max_len,
            max_batch_size=1,
            use_kv_cache=args.kv_cache,
            vocab_size=params["vocab_size"],
        )
        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
        self.model = _load_for_executorch(args.pte)
@@ -45,11 +45,17 @@ def forward(
4545 tokens : Optional [torch .LongTensor ] = None ,
4646 input_pos : Optional [torch .LongTensor ] = None ,
4747 ) -> torch .Tensor :
48- return (
49- self .model .forward ((tokens , input_pos ))
50- if input_pos is not None
51- else self .model .forward ((tokens ,))
52- )[0 ]
48+ # TODO: in LlamaRunner there is a generate function that automatically generates
49+ # input_pos tensor and inputs it into the model. Atm TorchTune models use
50+ # kwargs for the input_pos, so we will need to make some changes. At least
51+ # for the time being, we can run the non-kv cache version of the Torchtune
52+ # model with just the tokens like below.
53+ return (self .model .forward ((tokens ,)))[0 ]
54+ # return (
55+ # self.model.forward((tokens, input_pos))
56+ # if input_pos is not None
57+ # else self.model.forward((tokens,))
58+ # )[0]


def build_args_parser() -> argparse.ArgumentParser: