Closed

59 commits
7f81e00
Changes to native runner to run tt
jackzhxng Oct 9, 2024
0b5a9a7
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
a9647d2
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
fa3b1d2
Add kwarg example inputs to eager model base
jackzhxng Sep 30, 2024
e8715ba
Lint
jackzhxng Oct 8, 2024
a6f96a2
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
328c72c
Remove future implementation
jackzhxng Oct 5, 2024
ec80bba
Lint
jackzhxng Oct 15, 2024
c9bbe12
Create create new method for example kwarg inputs instead
jackzhxng Oct 7, 2024
99d5bfb
Accept model type parameter in export_llama
jackzhxng Oct 5, 2024
1fb2236
Torchtune llama3_2_vision model in ET, no quantization
jackzhxng Oct 5, 2024
e0c4b8a
Fix vision model example input
jackzhxng Oct 8, 2024
e145bd1
Lint
jackzhxng Oct 22, 2024
ed906cb
Kv cache
jackzhxng Oct 25, 2024
6dd47e7
Merge branch 'main' into jz/tt-llama
jackzhxng Oct 25, 2024
1825972
Update READMEs
jackzhxng Oct 25, 2024
196499a
Change model default arg
jackzhxng Oct 25, 2024
96ba40b
Update eager runner and eval llama
jackzhxng Oct 25, 2024
18a82e1
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
0f3035d
Fix tests
jackzhxng Oct 25, 2024
e677e14
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 25, 2024
b1f6678
Fix tests again
jackzhxng Oct 28, 2024
13d004b
Merge branch 'jz/tt-llama-rebased' into jz/tt-llama-2
jackzhxng Oct 28, 2024
c79b773
Strict = True
jackzhxng Oct 31, 2024
b8ff8e2
Things work
jackzhxng Oct 31, 2024
25ec7ce
Merge branch 'jz/tt-llama-rebased' into jz/native-runner-tt
jackzhxng Oct 31, 2024
6e38763
Clip logits if torchtune
jackzhxng Oct 31, 2024
7a7041d
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Oct 31, 2024
96d5798
Fix
jackzhxng Oct 31, 2024
f275e2e
Kv cache by default is false
jackzhxng Nov 1, 2024
37011d3
Clean up
jackzhxng Nov 1, 2024
7d52002
Export model with KV cache + runner for Torchtune models
jackzhxng Nov 4, 2024
e44b259
Export with no kv cache + non-strict load checkpoint
jackzhxng Nov 6, 2024
de45c48
Strict = True
jackzhxng Oct 31, 2024
2fe7bd8
Merge branch 'main' into jz/tt-llama-2
jackzhxng Nov 13, 2024
64dcbda
Lint
jackzhxng Nov 13, 2024
a89d6b2
Fix merge
jackzhxng Nov 13, 2024
e1ec74c
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Nov 13, 2024
84422d9
Fixes
jackzhxng Nov 13, 2024
1163769
Remove token count printing
jackzhxng Nov 13, 2024
a0e33d9
Merge branch 'jz/native-runner-tt' into jz/tt-llama-kv-cache
jackzhxng Nov 13, 2024
aa289ea
Fix faulty merge
jackzhxng Nov 13, 2024
eeeeb8a
Add runner
jackzhxng Nov 13, 2024
c80ce1c
Remove has_full_logits from llama runner
jackzhxng Nov 13, 2024
9bd405f
Lint
jackzhxng Nov 13, 2024
7507002
Modularize and update base eager runner
jackzhxng Nov 13, 2024
e5428de
Move to subdir
jackzhxng Nov 14, 2024
eefadaa
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Nov 14, 2024
bf33485
Merge remote-tracking branch 'origin/main' into jz/tt-llama-2
jackzhxng Nov 14, 2024
9c5647c
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Nov 14, 2024
f61a347
Tarun rev
jackzhxng Nov 14, 2024
a36703e
Merge branch 'jz/native-runner-tt' into jz/tt-llama-kv-cache
jackzhxng Nov 14, 2024
7a0101f
Add automatically generated export tests
jackzhxng Nov 14, 2024
9777e23
Fix internal pyre warning
jackzhxng Nov 14, 2024
2b9f281
Merge branch 'jz/tt-llama-2' into jz/native-runner-tt
jackzhxng Nov 14, 2024
f504cc5
Merge branch 'jz/native-runner-tt' into jz/tt-llama-kv-cache
jackzhxng Nov 14, 2024
1e26f60
Add executorch runner
jackzhxng Nov 15, 2024
b74e2c3
Merge remote-tracking branch 'origin/main' into jz/tt-llama-kv-cache
jackzhxng Nov 15, 2024
a168069
Add test for eager torchtune llama runner
jackzhxng Nov 15, 2024
66 changes: 66 additions & 0 deletions .ci/scripts/test_llama_3_2_vision.sh
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -exu

ENABLE_KV_CACHE="${1:-false}"

if [[ "${ENABLE_KV_CACHE}" != "true" && "${ENABLE_KV_CACHE}" != "false" ]]; then
echo "Error: ENABLE_KV_CACHE must be 'true' or 'false'"
exit 1
fi

if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
PYTHON_EXECUTABLE=python3
fi

download_dependencies() {
bash examples/models/llama3_2_vision/install_requirements.sh
tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
}

run_and_verify_eager() {
NOW=$(date +"%H:%M:%S")
echo "Starting to test llama3_2_vision text decoder at ${NOW}"
if [[ ! -f "/tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth" ]]; then
echo "checkpoint (consolidated.pth) is missing."
exit 1
fi
if [[ ! -f "/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model" ]]; then
echo "tokenizer.model is missing."
exit 1
fi

EAGER_RUNNER_ARGS="$PYTHON_EXECUTABLE -m examples.models.llama3_2_vision.runner.eager \
-c /tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth \
-t /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model \
-d fp32 \
--max_seq_length 32 \
--temperature 0 \
--show_tokens \
--prompt \"Once upon a time,\""

if [[ "${ENABLE_KV_CACHE}" == "true" ]]; then
EAGER_RUNNER_ARGS="${EAGER_RUNNER_ARGS} -kv"
fi

# Run the eager runner and capture its output.
eval "${EAGER_RUNNER_ARGS}" > result.txt

# Verify result.txt
RESULT=$(cat result.txt)
EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
echo "Actual result: ${RESULT}"
echo "Success"
exit 0
else
echo "Actual result: ${RESULT}"
echo "Failure; results not the same"
exit 1
fi
}

download_dependencies
run_and_verify_eager
28 changes: 28 additions & 0 deletions .github/workflows/pull.yml
@@ -502,3 +502,31 @@ jobs:

# run llama runner in eager mode
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_runner_eager.sh

test-llama_3_2_vision_runner_eager-linux:
name: test-llama_3_2_vision_runner_eager-linux
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
strategy:
matrix:
enable_kv_cache: ["true", "false"]
fail-fast: false
with:
runner: linux.24xlarge
docker-image: executorch-ubuntu-22.04-clang12
submodules: 'true'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
script: |
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
# install pybind
bash install_requirements.sh
# install llama requirements
bash examples/models/llama/install_requirements.sh

ENABLE_KV_CACHE=${{ matrix.enable_kv_cache }}

# run llama3_2_vision runner in eager mode
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_3_2_vision.sh "${ENABLE_KV_CACHE}"
12 changes: 7 additions & 5 deletions examples/models/llama/runner/eager.py
@@ -6,7 +6,7 @@

import argparse
import json
from typing import Optional
from typing import Optional, Type

import torch

@@ -33,7 +33,6 @@ def __init__(self, args):
max_batch_size=1,
use_kv_cache=args.use_kv_cache,
vocab_size=params["vocab_size"],
has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
device="cuda" if torch.cuda.is_available() else "cpu",
)
manager: LLMEdgeManager = _prepare_for_llama_export(args)
@@ -79,11 +78,10 @@ def build_args_parser() -> argparse.ArgumentParser:
return parser


def main() -> None:
def execute_runner(runner_class: Type[LlamaRunner]) -> None:
parser = build_args_parser()
args = parser.parse_args()

runner = EagerLlamaRunner(args)
runner = runner_class(args)
generated_tokens = (
runner.chat_completion(temperature=args.temperature)
if args.chat
@@ -97,5 +95,9 @@ def main() -> None:
print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")


def main() -> None:
execute_runner(EagerLlamaRunner)


if __name__ == "__main__":
main() # pragma: no cover
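
Note: the refactor above turns the old `main()` into a reusable `execute_runner(runner_class)` entry point. As a rough illustration (not code from this PR), any `LlamaRunner` subclass can now reuse the shared CLI parsing and generation loop; the `MyCustomRunner` name, its model loading, and the vocab size below are placeholders.

```python
# Sketch only: a hypothetical runner that plugs into execute_runner().
# Everything except LlamaRunner and execute_runner is illustrative.
from executorch.examples.models.llama.runner.eager import execute_runner
from executorch.examples.models.llama.runner.generation import LlamaRunner


class MyCustomRunner(LlamaRunner):
    def __init__(self, args):
        super().__init__(
            tokenizer_path=args.tokenizer_path,
            max_seq_len=args.max_seq_length,
            max_batch_size=1,
            use_kv_cache=args.use_kv_cache,
            vocab_size=128256,  # must equal tokenizer.n_words; value is a placeholder
        )
        self.model = ...  # build or load an eager model here (placeholder)

    def forward(self, tokens, input_pos=None):
        # Expected to return logits for the last position only (see generation.py).
        return self.model(tokens, input_pos)


if __name__ == "__main__":
    execute_runner(MyCustomRunner)  # parse args, build the runner, generate
```

The torchtune eager runner added later in this PR uses exactly this hook: it passes its own `EagerLlamaRunner` subclass to `execute_runner`.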
10 changes: 2 additions & 8 deletions examples/models/llama/runner/generation.py
@@ -53,7 +53,6 @@ def __init__(
max_batch_size: int,
use_kv_cache: bool,
vocab_size: int,
has_full_logits: bool = False,
device: str = "cpu",
):
"""
@@ -65,14 +64,12 @@
max_batch_size: max batch size.
use_kv_cache: whether to use a KV cache.
vocab_size: number of items in the vocab.
has_full_logits: whether the model returns the full logits or only returns the last logit.
device: device to run the runner on.
"""
self.max_seq_len = max_seq_len
self.max_batch_size = max_batch_size
self.use_kv_cache = use_kv_cache
self.tokenizer = get_tokenizer(tokenizer_path)
self.has_full_logits = has_full_logits
self.device = device
assert vocab_size == self.tokenizer.n_words

@@ -93,7 +90,7 @@ def generate(  # noqa: C901
echo: bool = False,
pos_base: int = 0,
) -> List[int]:
# prefill
# Prefill
logits = self.forward(
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
input_pos=(
@@ -128,10 +125,7 @@
)

# If the logits aren't already clipped to only contain the last logit, clip them.
if self.has_full_logits:
current_token = next_token(logits[:, -1, :], temperature, top_p)
else:
current_token = next_token(logits, temperature, top_p)
current_token = next_token(logits, temperature, top_p)
tokens.append(current_token)

if current_token == self.tokenizer.eos_id or (
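
Note: with `has_full_logits` gone, the shape contract is split between the two runners: the base `LlamaRunner.generate` samples from logits that already contain only the final position, while the TorchTune subclass added below receives full-sequence logits and slices `[:, -1, :]` itself. A small shape sketch (tensor sizes are illustrative):

```python
import torch

batch, seq_len, vocab = 1, 8, 128256  # illustrative sizes

# Base LlamaRunner: forward() is expected to return last-position logits.
last_only = torch.randn(batch, vocab)

# TorchTune text decoder: forward() returns logits for every position,
# so TorchTuneLlamaRunner slices the final step before calling next_token().
full = torch.randn(batch, seq_len, vocab)
assert full[:, -1, :].shape == last_only.shape  # both are (1, vocab)
```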
1 change: 0 additions & 1 deletion examples/models/llama/runner/native.py
@@ -41,7 +41,6 @@ def __init__(self, args):
max_batch_size=1,
use_kv_cache=args.kv_cache,
vocab_size=params["vocab_size"],
has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS,
)
self.model = _load_for_executorch(args.pte)

53 changes: 53 additions & 0 deletions examples/models/llama3_2_vision/runner/eager.py
@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from typing import Optional

import torch

from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export
from executorch.examples.models.llama.runner.eager import execute_runner
from executorch.examples.models.llama3_2_vision.runner.generation import (
TorchTuneLlamaRunner,
)
from executorch.extension.llm.export import LLMEdgeManager


class EagerLlamaRunner(TorchTuneLlamaRunner):
"""
Runs llama in eager mode with provided checkpoint file.
"""

def __init__(self, args):
with open(args.params, "r") as f:
params = json.loads(f.read())
super().__init__(
tokenizer_path=args.tokenizer_path,
max_seq_len=args.max_seq_length,
max_batch_size=1,
use_kv_cache=args.use_kv_cache,
vocab_size=params["vocab_size"],
device="cuda" if torch.cuda.is_available() else "cpu",
)
manager: LLMEdgeManager = _prepare_for_llama_export(args)
self.model = manager.model.eval().to(device=self.device)

def forward(
self,
tokens: Optional[torch.LongTensor] = None,
input_pos: Optional[torch.LongTensor] = None,
mask: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask)


def main() -> None:
execute_runner(EagerLlamaRunner)


if __name__ == "__main__":
main() # pragma: no cover
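
Note: because the file reuses `execute_runner`, it can also be driven programmatically with the same flags the CI script above passes on the command line. A minimal sketch, assuming default values for any flags not listed (the paths are the placeholders used by the CI script, and the in-process `sys.argv` construction is illustrative):

```python
# Illustrative only: invoke the eager runner's main() in-process with the
# same flags .ci/scripts/test_llama_3_2_vision.sh passes on the CLI.
import sys

from executorch.examples.models.llama3_2_vision.runner import eager

sys.argv = [
    "eager",
    "-c", "/tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth",
    "-t", "/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model",
    "-d", "fp32",
    "--max_seq_length", "32",
    "--temperature", "0",
    "--prompt", "Once upon a time,",
]
eager.main()  # build_args_parser() consumes sys.argv, then generation runs
```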
101 changes: 101 additions & 0 deletions examples/models/llama3_2_vision/runner/generation.py
@@ -0,0 +1,101 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token


class TorchTuneLlamaRunner(LlamaRunner):
def __init__(
self,
tokenizer_path: str,
max_seq_len: int,
max_batch_size: int,
use_kv_cache: bool,
vocab_size: int,
device: str = "cpu",
):
super().__init__(
tokenizer_path,
max_seq_len,
max_batch_size,
use_kv_cache,
vocab_size,
device,
)

self.causal_mask = torch.tril(
torch.ones(
size=(max_seq_len, max_seq_len),
dtype=torch.bool,
)
)
self.input_pos = torch.arange(max_seq_len)

def generate( # noqa: C901
self,
prompt_tokens: List[int],
max_seq_len: int,
temperature: float = 0.8,
top_p: float = 0.9,
echo: bool = False,
) -> List[int]:
# Prefill
seq_len = len(prompt_tokens)
input_pos = self.input_pos[None, :seq_len]
mask = self.causal_mask[None, :seq_len]
if self.use_kv_cache:
logits = self.forward(
tokens=torch.tensor(
[prompt_tokens], dtype=torch.long, device=self.device
),
input_pos=input_pos,
mask=mask,
)
else:
logits = self.forward(
tokens=torch.tensor(
[prompt_tokens], dtype=torch.long, device=self.device
),
)

# Only need the last logit.
current_token = next_token(logits[:, -1, :], temperature, top_p)
print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
tokens = prompt_tokens + [current_token]

while len(tokens) < max_seq_len:
mask = self.causal_mask[None, seq_len, None, :]
input_pos = self.input_pos[None, seq_len, None]
if self.use_kv_cache:
logits = self.forward(
tokens=torch.tensor(
[[current_token]], dtype=torch.long, device=self.device
),
input_pos=input_pos,
mask=mask,
)
else:
logits = self.forward(
tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
)

# Only need the last logit.
current_token = next_token(logits[:, -1, :], temperature, top_p)
tokens.append(current_token)

if current_token == self.tokenizer.eos_id or (
hasattr(self.tokenizer, "stop_tokens")
and current_token in self.tokenizer.stop_tokens
):
break

print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
seq_len += 1

return tokens if echo else tokens[len(prompt_tokens) :]
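
Note: the KV-cache path above leans on two precomputed tensors, `causal_mask` and `input_pos`, and slices them differently for prefill versus single-token decode. A standalone sketch of those slices and their shapes (`max_seq_len` and the prompt length are illustrative):

```python
import torch

max_seq_len = 8  # illustrative
prompt_len = 3   # illustrative

causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
input_pos = torch.arange(max_seq_len)

# Prefill: one mask row per prompt token, positions 0..prompt_len-1.
prefill_mask = causal_mask[None, :prompt_len]       # (1, prompt_len, max_seq_len)
prefill_pos = input_pos[None, :prompt_len]          # (1, prompt_len)

# Decode step: a single query token sitting at position prompt_len.
step_mask = causal_mask[None, prompt_len, None, :]  # (1, 1, max_seq_len)
step_pos = input_pos[None, prompt_len, None]        # (1, 1)

print(prefill_mask.shape, prefill_pos.shape, step_mask.shape, step_pos.shape)
```

Each decode step then advances `seq_len` by one, so the mask row and position index move forward together on the next iteration.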