diff --git a/.ci/scripts/test_llama_3_2_vision.sh b/.ci/scripts/test_llama_3_2_vision.sh
new file mode 100644
index 00000000000..ddcef703a6b
--- /dev/null
+++ b/.ci/scripts/test_llama_3_2_vision.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -exu
+
+ENABLE_KV_CACHE="${1:-false}"
+
+if [[ "${ENABLE_KV_CACHE}" != "true" && "${ENABLE_KV_CACHE}" != "false" ]]; then
+  echo "Error: ENABLE_KV_CACHE must be 'true' or 'false'"
+  exit 1
+fi
+
+if [[ -z "${PYTHON_EXECUTABLE:-}" ]]; then
+  PYTHON_EXECUTABLE=python3
+fi
+
+download_dependencies() {
+  bash examples/models/llama3_2_vision/install_requirements.sh
+  tune download meta-llama/Llama-3.2-11B-Vision-Instruct --output-dir /tmp/Llama-3.2-11B-Vision-Instruct
+}
+
+run_and_verify_eager() {
+  NOW=$(date +"%H:%M:%S")
+  echo "Starting to test llama3_2_vision text decoder at ${NOW}"
+  if [[ ! -f "/tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth" ]]; then
+    echo "checkpoint (consolidated.pth) is missing."
+    exit 1
+  fi
+  if [[ ! -f "/tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model" ]]; then
+    echo "tokenizer.model is missing."
+    exit 1
+  fi
+
+  EAGER_RUNNER_ARGS="$PYTHON_EXECUTABLE -m examples.models.llama3_2_vision.runner.eager \
+    -c /tmp/Llama-3.2-11B-Vision-Instruct/original/consolidated.pth \
+    -t /tmp/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model \
+    -d fp32 \
+    --max_seq_length 32 \
+    --temperature 0 \
+    --show_tokens \
+    --prompt \"Once upon a time,\""
+
+  if [[ "${ENABLE_KV_CACHE}" == "true" ]]; then
+    EAGER_RUNNER_ARGS="${EAGER_RUNNER_ARGS} -kv"
+  fi
+
+  # Run the eager runner and capture its output.
+  eval "${EAGER_RUNNER_ARGS}" > result.txt
+
+  # Verify result.txt
+  RESULT=$(cat result.txt)
+  EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
+  if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
+    echo "Actual result: ${RESULT}"
+    echo "Success"
+    exit 0
+  else
+    echo "Actual result: ${RESULT}"
+    echo "Failure; results not the same"
+    exit 1
+  fi
+}
+
+download_dependencies
+run_and_verify_eager
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index 1f5da06a920..9737ed55340 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -502,3 +502,31 @@ jobs:
 
         # run llama runner in eager mode
         PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_runner_eager.sh
+
+  test-llama_3_2_vision_runner_eager-linux:
+    name: test-llama_3_2_vision_runner_eager-linux
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    strategy:
+      matrix:
+        enable_kv_cache: ["true", "false"]
+      fail-fast: false
+    with:
+      runner: linux.24xlarge
+      docker-image: executorch-ubuntu-22.04-clang12
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
+        # install pybind
+        bash install_requirements.sh
+        # install llama requirements
+        bash examples/models/llama/install_requirements.sh
+
+        ENABLE_KV_CACHE=${{ matrix.enable_kv_cache }}
+
+        # run llama3_2_vision runner in eager mode
+        PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_3_2_vision.sh 
"${ENABLE_KV_CACHE}" diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index b2f29a8f6bb..2918e7b0503 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -6,7 +6,7 @@ import argparse import json -from typing import Optional +from typing import Optional, Type import torch @@ -33,7 +33,6 @@ def __init__(self, args): max_batch_size=1, use_kv_cache=args.use_kv_cache, vocab_size=params["vocab_size"], - has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS, device="cuda" if torch.cuda.is_available() else "cpu", ) manager: LLMEdgeManager = _prepare_for_llama_export(args) @@ -79,11 +78,10 @@ def build_args_parser() -> argparse.ArgumentParser: return parser -def main() -> None: +def execute_runner(runner_class: Type[LlamaRunner]) -> None: parser = build_args_parser() args = parser.parse_args() - - runner = EagerLlamaRunner(args) + runner = runner_class(args) generated_tokens = ( runner.chat_completion(temperature=args.temperature) if args.chat @@ -97,5 +95,9 @@ def main() -> None: print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}") +def main() -> None: + execute_runner(EagerLlamaRunner) + + if __name__ == "__main__": main() # pragma: no cover diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 34c8d2f893a..46033705126 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -53,7 +53,6 @@ def __init__( max_batch_size: int, use_kv_cache: bool, vocab_size: int, - has_full_logits: bool = False, device: str = "cpu", ): """ @@ -65,14 +64,12 @@ def __init__( max_batch_size: max batch size. use_kv_cache: whether to use a KV cache. vocab_size: number of items in the vocab. - has_full_logits: whether the model returns the full logits or only returns the last logit. device: device to run the runner on. """ self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size self.use_kv_cache = use_kv_cache self.tokenizer = get_tokenizer(tokenizer_path) - self.has_full_logits = has_full_logits self.device = device assert vocab_size == self.tokenizer.n_words @@ -93,7 +90,7 @@ def generate( # noqa: C901 echo: bool = False, pos_base: int = 0, ) -> List[int]: - # prefill + # Prefill logits = self.forward( tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), input_pos=( @@ -128,10 +125,7 @@ def generate( # noqa: C901 ) # If the logits aren't already clipped to only contain the last logit, clip them. 
- if self.has_full_logits: - current_token = next_token(logits[:, -1, :], temperature, top_p) - else: - current_token = next_token(logits, temperature, top_p) + current_token = next_token(logits, temperature, top_p) tokens.append(current_token) if current_token == self.tokenizer.eos_id or ( diff --git a/examples/models/llama/runner/native.py b/examples/models/llama/runner/native.py index 06ee8e3e713..62757506f3b 100644 --- a/examples/models/llama/runner/native.py +++ b/examples/models/llama/runner/native.py @@ -41,7 +41,6 @@ def __init__(self, args): max_batch_size=1, use_kv_cache=args.kv_cache, vocab_size=params["vocab_size"], - has_full_logits=args.model in TORCHTUNE_DEFINED_MODELS, ) self.model = _load_for_executorch(args.pte) diff --git a/examples/models/llama3_2_vision/runner/eager.py b/examples/models/llama3_2_vision/runner/eager.py new file mode 100644 index 00000000000..c5d91013077 --- /dev/null +++ b/examples/models/llama3_2_vision/runner/eager.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +from typing import Optional + +import torch + +from executorch.examples.models.llama.export_llama_lib import _prepare_for_llama_export +from executorch.examples.models.llama.runner.eager import execute_runner +from executorch.examples.models.llama3_2_vision.runner.generation import ( + TorchTuneLlamaRunner, +) +from executorch.extension.llm.export import LLMEdgeManager + + +class EagerLlamaRunner(TorchTuneLlamaRunner): + """ + Runs llama in eager mode with provided checkpoint file. + """ + + def __init__(self, args): + with open(args.params, "r") as f: + params = json.loads(f.read()) + super().__init__( + tokenizer_path=args.tokenizer_path, + max_seq_len=args.max_seq_length, + max_batch_size=1, + use_kv_cache=args.use_kv_cache, + vocab_size=params["vocab_size"], + device="cuda" if torch.cuda.is_available() else "cpu", + ) + manager: LLMEdgeManager = _prepare_for_llama_export(args) + self.model = manager.model.eval().to(device=self.device) + + def forward( + self, + tokens: Optional[torch.LongTensor] = None, + input_pos: Optional[torch.LongTensor] = None, + mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + return self.model.forward(tokens=tokens, input_pos=input_pos, mask=mask) + + +def main() -> None: + execute_runner(EagerLlamaRunner) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/examples/models/llama3_2_vision/runner/generation.py b/examples/models/llama3_2_vision/runner/generation.py new file mode 100644 index 00000000000..e17760fd852 --- /dev/null +++ b/examples/models/llama3_2_vision/runner/generation.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List + +import torch +from executorch.examples.models.llama.runner.generation import LlamaRunner, next_token + + +class TorchTuneLlamaRunner(LlamaRunner): + def __init__( + self, + tokenizer_path: str, + max_seq_len: int, + max_batch_size: int, + use_kv_cache: bool, + vocab_size: int, + device: str = "cpu", + ): + super().__init__( + tokenizer_path, + max_seq_len, + max_batch_size, + use_kv_cache, + vocab_size, + device, + ) + + self.causal_mask = torch.tril( + torch.ones( + size=(max_seq_len, max_seq_len), + dtype=torch.bool, + ) + ) + self.input_pos = torch.arange(max_seq_len) + + def generate( # noqa: C901 + self, + prompt_tokens: List[int], + max_seq_len: int, + temperature: float = 0.8, + top_p: float = 0.9, + echo: bool = False, + ) -> List[int]: + # Prefill + seq_len = len(prompt_tokens) + input_pos = self.input_pos[None, :seq_len] + mask = self.causal_mask[None, :seq_len] + if self.use_kv_cache: + logits = self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + input_pos=input_pos, + mask=mask, + ) + else: + logits = self.forward( + tokens=torch.tensor( + [prompt_tokens], dtype=torch.long, device=self.device + ), + ) + + # Only need the last logit. + current_token = next_token(logits[:, -1, :], temperature, top_p) + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + tokens = prompt_tokens + [current_token] + + while len(tokens) < max_seq_len: + mask = self.causal_mask[None, seq_len, None, :] + input_pos = self.input_pos[None, seq_len, None] + if self.use_kv_cache: + logits = self.forward( + tokens=torch.tensor( + [[current_token]], dtype=torch.long, device=self.device + ), + input_pos=input_pos, + mask=mask, + ) + else: + logits = self.forward( + tokens=torch.tensor([tokens], dtype=torch.long, device=self.device), + ) + + # Only need the last logit. + current_token = next_token(logits[:, -1, :], temperature, top_p) + tokens.append(current_token) + + if current_token == self.tokenizer.eos_id or ( + hasattr(self.tokenizer, "stop_tokens") + and current_token in self.tokenizer.stop_tokens + ): + break + + print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) + seq_len += 1 + + return tokens if echo else tokens[len(prompt_tokens) :] diff --git a/examples/models/llama3_2_vision/runner/native.py b/examples/models/llama3_2_vision/runner/native.py new file mode 100644 index 00000000000..9a28c94f9c2 --- /dev/null +++ b/examples/models/llama3_2_vision/runner/native.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import json +from typing import Optional + +import torch + +from executorch.examples.models.llama.export_llama_lib import ( + EXECUTORCH_DEFINED_MODELS, + TORCHTUNE_DEFINED_MODELS, +) +from executorch.examples.models.llama3_2_vision.runner.generation import ( + TorchTuneLlamaRunner, +) + +from executorch.extension.pybindings.portable_lib import _load_for_executorch + +# Load custom ops and quantized ops. +from executorch.extension.pybindings import portable_lib # noqa # usort: skip + +# Note: import this after portable_lib +from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip +from executorch.kernels import quantized # noqa + + +class NativeLlamaRunner(TorchTuneLlamaRunner): + """ + Runs llama via ExecuTorch with provided pte file. 
+ """ + + def __init__(self, args): + with open(args.params, "r") as f: + params = json.loads(f.read()) + super().__init__( + tokenizer_path=args.tokenizer, + max_seq_len=args.max_len, + max_batch_size=1, + use_kv_cache=args.kv_cache, + vocab_size=params["vocab_size"], + ) + self.model = _load_for_executorch(args.pte) + self.use_kv_cache = args.kv_cache + + def forward( + self, + tokens: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + mask: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + return ( + self.model.forward((tokens, input_pos, mask)) + if self.use_kv_cache + else self.model.forward((tokens,)) + )[0] + + +def build_args_parser() -> argparse.ArgumentParser: + # TODO: merge these with build_args_parser from export_llama_lib. + parser = argparse.ArgumentParser() + + parser.add_argument( + "--model", + default="llama3", + choices=EXECUTORCH_DEFINED_MODELS + TORCHTUNE_DEFINED_MODELS, + ) + + parser.add_argument( + "-f", + "--pte", + type=str, + default=None, + help="path to exported executorch .pte file", + ) + + parser.add_argument( + "-p", "--params", type=str, default=None, help="model params file" + ) + + parser.add_argument( + "-t", + "--tokenizer", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + type=str, + default="Hello", + ) + + parser.add_argument( + "--temperature", + type=float, + default=0.6, + ) + + parser.add_argument( + "-kv", + "--kv_cache", + action="store_true", + ) + + parser.add_argument( + "--max_len", + type=int, + default=128, + help="Maximum length of the generated response sequence.", + ) + + return parser + + +def main() -> None: + parser = build_args_parser() + args = parser.parse_args() + runner = NativeLlamaRunner(args) + generated_tokens = runner.text_completion( + prompt=args.prompt, + temperature=args.temperature, + ) + print(f"Response: {generated_tokens}") + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/examples/models/llama3_2_vision/text_decoder/model.py b/examples/models/llama3_2_vision/text_decoder/model.py index 73d79bb08e0..65c11b20322 100644 --- a/examples/models/llama3_2_vision/text_decoder/model.py +++ b/examples/models/llama3_2_vision/text_decoder/model.py @@ -136,12 +136,13 @@ def __init__(self, **kwargs): self.model_ = prune_output_vocab(self.model_, output_prune_map) - # if self.use_kv_cache: - # print("Setting up KV cache on the model...") - # self.model_.setup_caches( - # batch_size=1, - # dtype=self.dtype, - # ) + if self.use_kv_cache: + print("Setting up KV cache on the model...") + self.model_.setup_caches( + batch_size=1, + dtype=self.dtype, + decoder_max_seq_len=self.max_seq_len, + ) def get_eager_model(self) -> torch.nn.Module: if self.dtype: @@ -150,25 +151,34 @@ def get_eager_model(self) -> torch.nn.Module: return self.model_.to(torch.float16) def get_example_inputs(self): - return (torch.ones(1, 64, dtype=torch.long),) + return (torch.ones(1, 32, dtype=torch.long),) def get_example_kwarg_inputs(self): - # TODO: add input_pos and mask when after making cache work. - return { - # "mask": self.causal_mask[None, 64, None, :], - # "encoder_input": None, - # "encoder_mask": None, - # "input_pos": self.input_pos[None, 64] - } + # For export we must use the prefill versions of the + # causal mask and input_pos. 
+ if self.use_kv_cache: + return { + "input_pos": self.input_pos[None, :32], + "mask": self.causal_mask[None, :32], + # "encoder_input": None, + # "encoder_mask": None, + } + else: + return None def get_dynamic_shapes(self): batch_size = 1 dim_seq_len = torch.export.Dim("token_dim", min=1, max=self.max_seq_len) - dynamic_shapes = { - "tokens": {0: batch_size, 1: dim_seq_len}, - # "encoder_input": {0: 1, 1: dim_enc, 2: 4096}, - # "encoder_mask": {0: 1, 1: dim, 2: dim_enc}, - # "mask": {0: batch_size, 1: dim_seq_len, 2: self.max_seq_len}, - # "input_pos" : {0: batch_size, 1: dim_seq_len}, - } + if self.use_kv_cache: + dynamic_shapes = { + "tokens": {0: batch_size, 1: dim_seq_len}, + # "encoder_input": {0: 1, 1: dim_enc, 2: 4096}, + # "encoder_mask": {0: 1, 1: dim, 2: dim_enc}, + "mask": {0: batch_size, 1: dim_seq_len, 2: None}, + "input_pos": {0: batch_size, 1: dim_seq_len}, + } + else: + dynamic_shapes = { + "tokens": {0: batch_size, 1: dim_seq_len}, + } return dynamic_shapes diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py index 5abe5efe462..5b66aef8de7 100644 --- a/examples/models/model_factory.py +++ b/examples/models/model_factory.py @@ -44,7 +44,7 @@ def create_model( model = model_class(**kwargs) example_kwarg_inputs = None dynamic_shapes = None - if hasattr(model, "get_example_kwarg_inputs()"): + if hasattr(model, "get_example_kwarg_inputs"): example_kwarg_inputs = model.get_example_kwarg_inputs() if hasattr(model, "get_dynamic_shapes"): dynamic_shapes = model.get_dynamic_shapes()
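
Note on the runner refactor in examples/models/llama/runner/eager.py: main() is now a thin wrapper around execute_runner(runner_class), which owns argument parsing, runner construction, and generation, so the new llama3_2_vision eager runner only has to supply its own class. Below is a minimal sketch of reusing that hook for another eager runner; ToyRunner, its hard-coded vocab size, and its model stub are illustrative assumptions, not part of this patch.

    from typing import Optional

    import torch

    from executorch.examples.models.llama.runner.eager import execute_runner
    from executorch.examples.models.llama.runner.generation import LlamaRunner


    class ToyRunner(LlamaRunner):
        """Hypothetical eager runner driven by the shared execute_runner() hook."""

        def __init__(self, args):
            super().__init__(
                tokenizer_path=args.tokenizer_path,
                max_seq_len=args.max_seq_length,
                max_batch_size=1,
                use_kv_cache=args.use_kv_cache,
                vocab_size=128256,  # assumption: must equal the tokenizer's n_words
            )
            self.model = ...  # load an eager torch.nn.Module here

        def forward(
            self,
            tokens: Optional[torch.LongTensor] = None,
            input_pos: Optional[torch.LongTensor] = None,
        ) -> torch.Tensor:
            return self.model(tokens, input_pos)


    if __name__ == "__main__":
        # execute_runner builds the shared arg parser, instantiates ToyRunner(args),
        # and runs chat or text completion, just as main() does for EagerLlamaRunner.
        execute_runner(ToyRunner)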
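The has_full_logits flag removed from LlamaRunner is replaced by the subclass split: TorchTune-defined decoders return logits for every position, so TorchTuneLlamaRunner samples from logits[:, -1, :], while the existing llama runner keeps sampling the already-clipped last logit. A small self-contained illustration of that slicing, with toy sizes and a greedy argmax standing in for the next_token sampler:

    import torch

    batch_size, seq_len, vocab_size = 1, 5, 128  # toy sizes
    full_logits = torch.randn(batch_size, seq_len, vocab_size)  # shape a TorchTune decoder returns

    last_logits = full_logits[:, -1, :]  # [1, 128]; only the final position drives sampling
    next_tok = torch.argmax(last_logits, dim=-1).item()  # greedy pick, i.e. the temperature == 0 path
    print(next_tok)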
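TorchTuneLlamaRunner.generate and the text decoder's get_example_kwarg_inputs carve the same buffers: prefill takes the first seq_len rows of the causal mask and position ids, and each decode step takes the single row at position seq_len. The shapes those slices produce can be checked in isolation with toy sizes:

    import torch

    max_seq_len = 8  # toy stand-in for the model's max_seq_len
    causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
    input_pos = torch.arange(max_seq_len)

    # Prefill over a 3-token prompt: one mask row per prompt position.
    seq_len = 3
    print(causal_mask[None, :seq_len].shape)          # torch.Size([1, 3, 8])
    print(input_pos[None, :seq_len])                  # tensor([[0, 1, 2]])

    # Decode step for the token at position seq_len: a single mask row and position.
    print(causal_mask[None, seq_len, None, :].shape)  # torch.Size([1, 1, 8])
    print(input_pos[None, seq_len, None])             # tensor([[3]])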
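In the new native runner, the exported decoder is driven through the ExecuTorch pybindings: positional inputs are packed into one tuple and the outputs come back as a list, with the logits at index 0. A sketch of a single KV-cache prefill call, assuming a hypothetical .pte path and toy sizes that mirror get_example_kwarg_inputs; none of these values come from the patch itself.

    import torch

    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    # Hypothetical path; produce the file with the export flow for the llama3_2_vision text decoder.
    module = _load_for_executorch("/tmp/llama3_2_vision_text_decoder.pte")

    max_seq_len = 512  # assumption: must match the max_seq_len the model was exported with
    seq_len = 32
    tokens = torch.ones(1, seq_len, dtype=torch.long)
    input_pos = torch.arange(max_seq_len)[None, :seq_len]
    mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))[None, :seq_len]

    # Inputs go in as one tuple, outputs come back as a list; index 0 is the logits
    # tensor, matching NativeLlamaRunner.forward above.
    logits = module.forward((tokens, input_pos, mask))[0]
    print(logits.shape)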