
Commit 423f65d

helunwencser authored and facebook-github-bot committed
update eager runner to use same options for loading the model (#6257)
Summary:
Pull Request resolved: #6257

imported-using-ghimport

Test Plan: Imported from OSS

Run the following command and make sure it generates the right result:

```
python -m examples.models.llama2.runner.eager \
    -c /home/lunwenh/models/1B_Instruct/consolidated.00.pth \
    -p /home/lunwenh/models/1B_Instruct/params.json \
    -t /home/lunwenh/models/1B_Instruct/tokenizer.model \
    --max_seq_length 128 \
    -kv \
    --prompt "<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a good assistant<|eot_id|><|start_header_id|>user<|end_header_id|> What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
```

```
Response: The capital of France is Paris.<|eot_id|>
Tokens: [791, 6864, 315, 9822, 374, 12366, 16134, 91, 68, 354, 851, 91, 29]
```

Reviewed By: mergennachin

Differential Revision: D64442224

Pulled By: helunwencser

fbshipit-source-id: bb8b11de6325ae76423b086491094a4444249553
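The `--prompt` value in the test plan is a hand-written Llama 3 Instruct chat template. For readers reproducing the test, here is a minimal sketch of assembling the same string in Python; `llama3_chat_prompt` is a hypothetical helper, not part of this commit, and the special tokens are exactly the ones in the command above:

```python
# Minimal sketch, not repository code: build the Llama 3 Instruct
# chat-template prompt used in the test plan above.
def llama3_chat_prompt(system: str, user: str) -> str:
    # Special tokens copied verbatim from the test-plan command.
    return (
        "<|begin_of_text|>"
        f"<|start_header_id|>system<|end_header_id|> {system}<|eot_id|>"
        f"<|start_header_id|>user<|end_header_id|> {user}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>"
    )

# Reproduces the --prompt argument from the command above verbatim.
prompt = llama3_chat_prompt("You are a good assistant", "What is the capital of France?")
```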
1 parent 1d7c3ab commit 423f65d

File tree

1 file changed: +15 -43 lines changed

examples/models/llama2/runner/eager.py

Lines changed: 15 additions & 43 deletions
```diff
@@ -11,9 +11,12 @@
 import torch
 
 from examples.models.llama2.llama_transformer import ModelArgs
-from executorch.examples.models.model_factory import EagerModelFactory
-
-from .generation import LlamaRunner
+from executorch.examples.models.llama2.export_llama_lib import (
+    _prepare_for_llama_export,
+    build_args_parser as _build_args_parser,
+)
+from executorch.examples.models.llama2.runner.generation import LlamaRunner
+from executorch.extension.llm.export import LLMEdgeManager
 
 
 class EagerLlamaRunner(LlamaRunner):
@@ -25,21 +28,17 @@ def __init__(self, args):
         with open(args.params, "r") as f:
             params = json.loads(f.read())
         model_args: ModelArgs = ModelArgs(
-            max_seq_len=args.max_len,
+            max_seq_len=args.max_seq_length,
             max_batch_size=1,
-            use_kv_cache=True,
+            use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer, model_args=model_args)
-        self.model, _, _, _ = EagerModelFactory.create_model(
-            "llama2",
-            "Llama2Model",
-            checkpoint=args.checkpoint,
-            params=args.params,
-            use_kv_cache=True,
-            fairseq2=False,
-            max_seq_len=args.max_len,
-            enable_dynamic_shape=True,
+        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = (
+            manager.model.eval().to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.model.eval().to(device="cpu")
         )
 
     def forward(
@@ -51,34 +50,7 @@ def forward(
 
 
 def build_args_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument(
-        "--checkpoint",
-        type=str,
-        default=None,
-        help="path to model checkpoint file",
-    )
-
-    parser.add_argument(
-        "--params",
-        type=str,
-        default=None,
-        help="model params file",
-    )
-
-    parser.add_argument(
-        "--max_len",
-        type=int,
-        default=128,
-        help="Maximum length of the generated response sequence.",
-    )
-
-    parser.add_argument(
-        "--tokenizer",
-        type=str,
-        default=None,
-    )
+    parser = _build_args_parser()
 
     parser.add_argument(
         "--prompt",
```

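After this change the eager runner resolves its model through the same helpers the export flow uses, instead of `EagerModelFactory`. A minimal sketch of the resulting loading path, composed only from the imports and calls visible in the diff above (file paths and flag values are placeholders):

```python
# Minimal sketch, assembled from the diff above; paths are placeholders.
import torch

from executorch.examples.models.llama2.export_llama_lib import (
    _prepare_for_llama_export,
    build_args_parser as _build_args_parser,
)
from executorch.extension.llm.export import LLMEdgeManager

# The shared parser accepts the same flags as the export flow (and the
# test-plan command above): -c/-p/-t, --max_seq_length, -kv.
args = _build_args_parser().parse_args(
    [
        "-c", "/path/to/consolidated.00.pth",
        "-p", "/path/to/params.json",
        "-t", "/path/to/tokenizer.model",
        "--max_seq_length", "128",
        "-kv",
    ]
)

manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)

# Same device selection as the diff: prefer CUDA when available.
model = (
    manager.model.eval().to(device="cuda")
    if torch.cuda.is_available()
    else manager.model.eval().to(device="cpu")
)
```

Because `_build_args_parser` comes from the export path, flags such as `-kv` (which sets `args.use_kv_cache` in the diff) and `--max_seq_length` mean the same thing for eager running as for export, which is what the commit title refers to.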
0 commit comments