Skip to content

Commit 7f81e00

Browse files
committed
Changes to native runner to run TorchTune ("tt") models
1 parent a922b3d commit 7f81e00

File tree

1 file changed

+15
-9
lines changed

examples/models/llama2/runner/native.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,17 @@
1010

1111
import torch
1212

13-
from examples.models.llama2.llama_transformer import ModelArgs
13+
from executorch.examples.models.llama2.llama_transformer import ModelArgs
1414
from executorch.extension.pybindings.portable_lib import _load_for_executorch
1515

1616
# Load custom ops and quantized ops.
1717
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
1818

1919
# Note: import this after portable_lib
20-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
20+
# from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
2121
from executorch.kernels import quantized # noqa
2222

23-
from .generation import LlamaRunner
23+
from executorch.examples.models.llama2.runner.generation import LlamaRunner
2424

2525

2626
class NativeLlamaRunner(LlamaRunner):
@@ -35,7 +35,7 @@ def __init__(self, args):
3535
max_seq_len=args.max_len,
3636
max_batch_size=1,
3737
use_kv_cache=args.kv_cache,
38-
**params,
38+
vocab_size=params["vocab_size"],
3939
)
4040
super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
4141
self.model = _load_for_executorch(args.pte)
@@ -45,11 +45,17 @@ def forward(
4545
tokens: Optional[torch.LongTensor] = None,
4646
input_pos: Optional[torch.LongTensor] = None,
4747
) -> torch.Tensor:
48-
return (
49-
self.model.forward((tokens, input_pos))
50-
if input_pos is not None
51-
else self.model.forward((tokens,))
52-
)[0]
48+
# TODO: in LlamaRunner there is a generate function that automatically generates
49+
# input_pos tensor and inputs it into the model. Atm TorchTune models use
50+
# kwargs for the input_pos, so we will need to make some changes. At least
51+
# for the time being, we can run the non-kv cache version of the Torchtune
52+
# model with just the tokens like below.
53+
return (self.model.forward((tokens,)))[0]
54+
# return (
55+
# self.model.forward((tokens, input_pos))
56+
# if input_pos is not None
57+
# else self.model.forward((tokens,))
58+
# )[0]
5359

5460

5561
def build_args_parser() -> argparse.ArgumentParser:

0 commit comments

Comments (0)