Merged

40 commits
190ac50  Update (larryliu0820, Sep 12, 2025)
693c759  Make it work (larryliu0820, Sep 13, 2025)
568f50c  Add readme (larryliu0820, Sep 13, 2025)
72fc953  move test to test/ (larryliu0820, Sep 13, 2025)
e4ffbbe  Fix tests (larryliu0820, Sep 15, 2025)
1e76ded  Fix (larryliu0820, Sep 15, 2025)
6fc63d7  Rename test (larryliu0820, Sep 15, 2025)
4c1c1d0  make_image_input take tensor (larryliu0820, Sep 16, 2025)
a182c0b  More changes (larryliu0820, Sep 17, 2025)
7b7f360  More changes (larryliu0820, Sep 18, 2025)
5be86d2  Address comments (larryliu0820, Sep 18, 2025)
5742baf  Add support for audio and token input (larryliu0820, Sep 19, 2025)
34055ea  Remove notebook (larryliu0820, Sep 19, 2025)
0c2bbdb  Remove utils.py (larryliu0820, Sep 20, 2025)
06dcf71  Add CI jobs (larryliu0820, Sep 22, 2025)
d11298e  Fix tests (larryliu0820, Sep 22, 2025)
7c6266e  Update (larryliu0820, Sep 12, 2025)
4744451  Make it work (larryliu0820, Sep 13, 2025)
4a2169c  Add readme (larryliu0820, Sep 13, 2025)
b9ffd48  move test to test/ (larryliu0820, Sep 13, 2025)
844b61a  Fix tests (larryliu0820, Sep 15, 2025)
faab6b1  Fix (larryliu0820, Sep 15, 2025)
ed3623e  Rename test (larryliu0820, Sep 15, 2025)
e6e33a7  Address comments (larryliu0820, Sep 18, 2025)
236dd41  Add support for audio and token input (larryliu0820, Sep 19, 2025)
60200a8  Remove notebook (larryliu0820, Sep 19, 2025)
41c0a02  Remove utils.py (larryliu0820, Sep 20, 2025)
1fa8de0  Rebase (larryliu0820, Sep 22, 2025)
59b6c98  Fix CI (larryliu0820, Sep 22, 2025)
7ba7d88  Retry fixing CI (larryliu0820, Sep 22, 2025)
b6d540d  More fixes? (larryliu0820, Sep 22, 2025)
f20417c  Fix mac (larryliu0820, Sep 22, 2025)
5222068  Lint (larryliu0820, Sep 22, 2025)
9d5844a  Try to fix macos CI (larryliu0820, Sep 23, 2025)
b27da8c  Fix macos CI 2 (larryliu0820, Sep 23, 2025)
4218c04  Try to fix windows (larryliu0820, Sep 23, 2025)
6e71eec  Fix windows build (larryliu0820, Sep 23, 2025)
79fd073  Fix wheel build (larryliu0820, Sep 23, 2025)
22a2233  Lint (larryliu0820, Sep 23, 2025)
f8ace7d  Remove llava (larryliu0820, Sep 23, 2025)
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/optimum-executorch.txt
@@ -1 +1 @@
828ae02053a6e0e20a2dfd6e737ba10c6f4dee6b
bd06b54e627fbfd354a2cffa4c80fb21883209a9
122 changes: 114 additions & 8 deletions .ci/scripts/test_huggingface_optimum_model.py
@@ -43,7 +43,9 @@ def cli_export(command, model_dir):


def check_causal_lm_output_quality(
model_id: str, generated_tokens: List[int], max_perplexity_threshold: float = 100.0
model_id: str,
generated_tokens: List[int],
max_perplexity_threshold: float = 100.0,
):
"""
Evaluates the quality of text generated by a causal language model by calculating its perplexity.
@@ -58,12 +60,24 @@ def check_causal_lm_output_quality(
"""
logging.info(f"Starting perplexity check with model '{model_id}' ...")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_id,
low_cpu_mem_usage=True,
use_cache=False,
torch_dtype=torch.bfloat16,
)
cls_name = AutoModelForCausalLM
if "llava" in model_id:
from transformers import LlavaForConditionalGeneration

cls_name = LlavaForConditionalGeneration
try:
model = cls_name.from_pretrained(
model_id,
low_cpu_mem_usage=True,
use_cache=False,
torch_dtype=torch.bfloat16,
)
except TypeError:
model = cls_name.from_pretrained(
model_id,
low_cpu_mem_usage=True,
torch_dtype=torch.bfloat16,
)

with torch.no_grad():
outputs = model(input_ids=generated_tokens, labels=generated_tokens)
@@ -156,6 +170,86 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
assert check_causal_lm_output_quality(model_id, generated_tokens) is True


def test_llm_with_image_modality(
model_id, model_dir, recipe, *, quantize=True, run_only=False
):
command = [
"optimum-cli",
"export",
"executorch",
"--model",
model_id,
"--task",
"multimodal-text-to-text",
"--recipe",
recipe,
"--output_dir",
model_dir,
"--use_custom_sdpa",
"--use_custom_kv_cache",
"--qlinear",
"8da4w",
"--qembedding",
"8w",
]
if not run_only:
cli_export(command, model_dir)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(model_dir)

# input
processor = AutoProcessor.from_pretrained(model_id)
image_url = "https://llava-vl.github.io/static/images/view.jpg"
conversation = [
{
"role": "system",
"content": [
{
"type": "text",
"text": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
}
],
},
{
"role": "user",
"content": [
{"type": "image", "url": image_url},
{
"type": "text",
"text": "What are the things I should be cautious about when I visit here?",
},
],
},
]
inputs = processor.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)

from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner

runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
generated_text = runner.generate_text_hf(
inputs,
GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
processor.image_token_id,
)
print(f"\nGenerated text:\n\t{generated_text}")
# Free memory before loading eager for quality check
del runner
gc.collect()
assert (
check_causal_lm_output_quality(
model_id, tokenizer.encode(generated_text, return_tensors="pt")
)
is True
)


def test_fill_mask(model_id, model_dir, recipe, *, quantize=True, run_only=False):
command = [
"optimum-cli",
@@ -353,6 +447,9 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
required=False,
help="When provided, write the pte file to this directory. Otherwise, a temporary directory is created for the test.",
)
parser.add_argument(
"--run_only", action="store_true", help="Skip export and only run the test"
)
args = parser.parse_args()

_text_generation_mapping = {
@@ -384,8 +481,16 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
"vit": ("google/vit-base-patch16-224", test_vit),
}

_multimodal_model_mapping = {
"gemma3-4b": ("google/gemma-3-4b-it", test_llm_with_image_modality),
"llava": ("llava-hf/llava-1.5-7b-hf", test_llm_with_image_modality),
}

model_to_model_id_and_test_function = (
_text_generation_mapping | _mask_fill_mapping | _misc_model_mapping
_text_generation_mapping
| _mask_fill_mapping
| _misc_model_mapping
| _multimodal_model_mapping
)

if args.model not in model_to_model_id_and_test_function:
@@ -400,4 +505,5 @@ def test_vit(model_id, model_dir, recipe, *, quantize=False, run_only=False):
model_dir=tmp_dir if args.model_dir is None else args.model_dir,
recipe=args.recipe,
quantize=args.quantize,
run_only=args.run_only,
)
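
Distilled from the new test above, a minimal sketch of the Python surface this PR adds. MultimodalRunner, generate_text_hf, GenerationConfig, and image_token_id are taken from the diff; the export directory, prompt, and model id are illustrative and assume a prior optimum-cli export with --task multimodal-text-to-text.

# Sketch: drive an exported multimodal model from Python.
# Assumes model.pte and tokenizer.model already exist under model_dir.
from transformers import AutoProcessor
from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner

model_id = "google/gemma-3-4b-it"  # one of the models wired into CI
model_dir = "./gemma3-4b"          # illustrative export output directory

processor = AutoProcessor.from_pretrained(model_id)
inputs = processor.apply_chat_template(
    [{"role": "user", "content": [
        {"type": "image", "url": "https://llava-vl.github.io/static/images/view.jpg"},
        {"type": "text", "text": "What do you see?"},
    ]}],
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
print(runner.generate_text_hf(
    inputs,
    GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
    processor.image_token_id,
))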
32 changes: 20 additions & 12 deletions .github/workflows/pull.yml
@@ -286,15 +286,20 @@ jobs:
# Test selective build
PYTHON_EXECUTABLE=python bash examples/selective_build/test_selective_build.sh "${BUILD_TOOL}"
test-llava-runner-linux:
name: test-llava-runner-linux
test-multimodal-linux:
if: ${{ !github.event.pull_request.head.repo.fork }}
name: test-multimodal-linux
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
permissions:
id-token: write
contents: read
secrets: inherit
strategy:
fail-fast: false
matrix:
model: ["gemma3-4b"] # llava gives segfault so not covering.
with:
secrets-env: EXECUTORCH_HF_TOKEN
runner: linux.24xlarge
docker-image: ci-image:executorch-ubuntu-22.04-clang12
submodules: 'recursive'
Expand All @@ -305,17 +310,20 @@ jobs:
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"
echo "::group::Setup ExecuTorch"
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "cmake"
# install Llava requirements
bash examples/models/llama/install_requirements.sh
bash examples/models/llava/install_requirements.sh
# run python unittest
python -m unittest examples.models.llava.test.test_llava
# run e2e (export, tokenizer and runner)
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llava.sh
echo "::endgroup::"
echo "::group::Setup Huggingface"
pip install -U "huggingface_hub[cli]" accelerate
huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
echo "::endgroup::"
echo "::group::Test ${{ matrix.model }}"
python .ci/scripts/test_huggingface_optimum_model.py --model ${{ matrix.model }} --quantize --recipe xnnpack
echo "::endgroup::"
test-moshi-linux:
name: test-moshi-linux
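To reproduce the new CI step outside the workflow, a hedged sketch: it simply shells out to the same command as the test-multimodal-linux job above, and assumes ExecuTorch, the pinned optimum-executorch, and a Hugging Face login are already set up locally.

# Sketch: re-run the workflow's test step locally; the command is copied
# from pull.yml above. Assumes deps are installed and HF auth is done.
import subprocess

subprocess.run(
    [
        "python",
        ".ci/scripts/test_huggingface_optimum_model.py",
        "--model", "gemma3-4b",
        "--quantize",
        "--recipe", "xnnpack",
    ],
    check=True,
)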
67 changes: 39 additions & 28 deletions .github/workflows/trunk.yml
@@ -616,34 +616,45 @@
bash .ci/scripts/test_torchao_huggingface_checkpoints.sh ${{ matrix.model }} ${{ matrix.test_with_runner && '--test_with_runner' || '' }}
# # TODO(jackzhxng): Runner consistently runs out of memory before test finishes. Try to find a more powerful runner.
# test-llava-runner-macos:
# name: test-llava-runner-macos
# uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
# strategy:
# fail-fast: false
# with:
# runner: macos-14-xlarge
# python-version: '3.11'
# submodules: 'recursive'
# ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
# timeout: 900
# script: |
# BUILD_TOOL=cmake

# bash .ci/scripts/setup-conda.sh
# # Setup MacOS dependencies as there is no Docker support on MacOS atm
# GITHUB_RUNNER=1 PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/setup-macos.sh --build-tool "${BUILD_TOOL}"

# # install Llava requirements
# ${CONDA_RUN} bash examples/models/llama/install_requirements.sh
# ${CONDA_RUN} bash examples/models/llava/install_requirements.sh

# # run python unittest
# ${CONDA_RUN} python -m unittest examples.models.llava.test.test_llava

# # run e2e (export, tokenizer and runner)
# PYTHON_EXECUTABLE=python ${CONDA_RUN} bash .ci/scripts/test_llava.sh
test-multimodal-macos:
if: ${{ !github.event.pull_request.head.repo.fork }}
name: test-multimodal-macos
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
permissions:
id-token: write
contents: read
secrets: inherit
strategy:
fail-fast: false
matrix:
model: ["gemma3-4b"] # llava gives segfault so not covering.
with:
secrets-env: EXECUTORCH_HF_TOKEN
runner: macos-15-xlarge
python-version: '3.11'
submodules: 'recursive'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
timeout: 90
script: |
echo "::group::Set up ExecuTorch"
bash .ci/scripts/setup-conda.sh
eval "$(conda shell.bash hook)"
# Install requirements
${CONDA_RUN} python install_executorch.py
echo "::endgroup::"
echo "::group::Set up Huggingface"
${CONDA_RUN} pip install -U "huggingface_hub[cli]" accelerate
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
${CONDA_RUN} pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}
${CONDA_RUN} pip list
echo "::endgroup::"
echo "::group::Test ${{ matrix.model }}"
${CONDA_RUN} python .ci/scripts/test_huggingface_optimum_model.py --model ${{ matrix.model }} --quantize --recipe xnnpack
echo "::endgroup::"
test-qnn-model:
name: test-qnn-model
18 changes: 9 additions & 9 deletions CMakeLists.txt
@@ -650,15 +650,6 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
list(APPEND _executorch_extensions tokenizers)
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
list(APPEND _executorch_extensions extension_llm_runner)
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
endif()

if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
install(
@@ -904,6 +895,15 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
list(APPEND _executorch_extensions extension_training)
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
list(APPEND _executorch_extensions extension_llm_runner)
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
endif()

if(EXECUTORCH_BUILD_KERNELS_LLM)
# TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)
7 changes: 1 addition & 6 deletions examples/models/llava/install_requirements.sh
@@ -7,9 +7,4 @@

set -x

pip install transformers accelerate sentencepiece tiktoken

# Run llama2/install requirements for torchao deps
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

bash "$SCRIPT_DIR"/../llama/install_requirements.sh
pip install git+https://github.com/huggingface/optimum-executorch.git@d4d3046738ca31b5542506aaa76a28d540600227
3 changes: 1 addition & 2 deletions examples/models/llava/main.cpp
@@ -131,8 +131,7 @@
#endif
// Load tokenizer
std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
std::make_unique<tokenizers::Llama2cTokenizer>();
tokenizer->load(tokenizer_path);
::executorch::extension::llm::load_tokenizer(tokenizer_path);
if (tokenizer == nullptr) {
ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path);
return 1;
40 changes: 40 additions & 0 deletions extension/llm/runner/CMakeLists.txt
@@ -79,3 +79,43 @@
if(BUILD_TESTING)
add_subdirectory(test)
endif()

# Python bindings for MultimodalRunner
if(EXECUTORCH_BUILD_PYBIND)
# Create the Python extension module for LLM runners
pybind11_add_module(
_llm_runner SHARED ${CMAKE_CURRENT_SOURCE_DIR}/pybindings.cpp
)

find_package_torch()
find_library(
TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib"
)
# Link with the extension_llm_runner library and its dependencies
target_link_libraries(
_llm_runner PRIVATE extension_llm_runner tokenizers::tokenizers
portable_lib ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES}
)

# Set properties for the Python extension
set_target_properties(
_llm_runner
PROPERTIES POSITION_INDEPENDENT_CODE ON
CXX_VISIBILITY_PRESET "hidden"
INTERPROCEDURAL_OPTIMIZATION TRUE
)
if(APPLE)
set(RPATH "@loader_path/../../pybindings")
else()
set(RPATH "$ORIGIN/../../pybindings")
endif()
set_target_properties(_llm_runner PROPERTIES INSTALL_RPATH ${RPATH})
# Add include directories
target_include_directories(
_llm_runner PRIVATE ${_common_include_directories} ${TORCH_INCLUDE_DIRS}
)

install(TARGETS _llm_runner
LIBRARY DESTINATION executorch/extension/llm/runner
)
endif()
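
A quick smoke test (a sketch, not part of the PR) that the freshly built _llm_runner extension resolves from the package path it installs into; constructing a GenerationConfig exercises the bindings without needing a .pte file.

# Sketch: verify the compiled _llm_runner module is importable via
# executorch/extension/llm/runner, its install destination above.
from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner

config = GenerationConfig(max_new_tokens=8, temperature=0, echo=False)
print("bindings loaded:", type(config).__name__, MultimodalRunner.__name__)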