-
Couldn't load subscription status.
- Fork 701
Runner changes for TorchTune Llama3.2 vision text decoder #6610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 37 commits
7f81e00
0b5a9a7
a9647d2
fa3b1d2
e8715ba
a6f96a2
328c72c
ec80bba
c9bbe12
99d5bfb
1fb2236
e0c4b8a
e145bd1
ed906cb
6dd47e7
1825972
196499a
96ba40b
18a82e1
0f3035d
e677e14
b1f6678
13d004b
c79b773
b8ff8e2
25ec7ce
6e38763
7a7041d
96d5798
f275e2e
37011d3
de45c48
2fe7bd8
64dcbda
a89d6b2
e1ec74c
84422d9
1163769
e5428de
eefadaa
bf33485
9c5647c
f61a347
7a0101f
9777e23
2b9f281
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,18 +10,22 @@ | |
|
|
||
| import torch | ||
|
|
||
| from examples.models.llama.llama_transformer import ModelArgs | ||
| from executorch.examples.models.llama.export_llama_lib import ( | ||
| EXECUTORCH_DEFINED_MODELS, | ||
| TORCHTUNE_DEFINED_MODELS, | ||
| ) | ||
|
|
||
| from executorch.extension.pybindings.portable_lib import _load_for_executorch | ||
|
|
||
| # Load custom ops and quantized ops. | ||
| from executorch.extension.pybindings import portable_lib # noqa # usort: skip | ||
|
|
||
| from executorch.examples.models.llama.runner.generation import LlamaRunner | ||
|
|
||
| # Note: import this after portable_lib | ||
| from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip | ||
| from executorch.kernels import quantized # noqa | ||
|
|
||
| from .generation import LlamaRunner | ||
|
|
||
|
|
||
| class NativeLlamaRunner(LlamaRunner): | ||
| """ | ||
|
|
@@ -31,13 +35,14 @@ class NativeLlamaRunner(LlamaRunner): | |
| def __init__(self, args): | ||
| with open(args.params, "r") as f: | ||
| params = json.loads(f.read()) | ||
| model_args: ModelArgs = ModelArgs( | ||
| super().__init__( | ||
| tokenizer_path=args.tokenizer, | ||
| max_seq_len=args.max_len, | ||
| max_batch_size=1, | ||
| use_kv_cache=args.kv_cache, | ||
| **params, | ||
| vocab_size=params["vocab_size"], | ||
| has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS, | ||
| ) | ||
| super().__init__(tokenizer_path=args.tokenizer, model_args=model_args) | ||
| self.model = _load_for_executorch(args.pte) | ||
|
|
||
| def forward( | ||
|
|
@@ -53,8 +58,15 @@ def forward( | |
|
|
||
|
|
||
| def build_args_parser() -> argparse.ArgumentParser: | ||
| # TODO: merge these with build_args_parser from export_llama_lib. | ||
| parser = argparse.ArgumentParser() | ||
|
|
||
| parser.add_argument( | ||
| "--model", | ||
| default="llama3", | ||
| choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "-f", | ||
| "--pte", | ||
|
|
@@ -89,7 +101,6 @@ def build_args_parser() -> argparse.ArgumentParser: | |
| parser.add_argument( | ||
| "-kv", | ||
| "--kv_cache", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we'd want the default to still be True? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh yeah, this was weird: `store_true` works by having a default of False, but since the default here is set to True, the flag is always True regardless of what you pass.
||
| default=True, | ||
| action="store_true", | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| from .model import Llama3_2Decoder | ||
|
|
||
| __all__ = [Llama3_2Decoder] |
Uh oh!
There was an error while loading. Please reload this page.