73 changes: 69 additions & 4 deletions .ci/scripts/test_ane_static_llama.sh
@@ -28,13 +28,78 @@ pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
# Download stories llama110m artifacts
download_stories_model_artifacts

# Test static ANE llama model
# Test static ANE llama model export
echo "Exporting static ANE llama model..."
python export_static_llm_coreml.py --checkpoint stories110M.pt --params params.json --output model.pte

# The ANE cannot run in github CI
# python run_static_llm.py --model model.pte --params params.json --tokenizer tokenizer.model --prompt "Once upon a time," --lookahead
# The ANE is not accessible in GitHub CI, so export a CPU-only model to exercise the runner
echo "Exporting CPU-only model for CI testing..."
python export_static_llm_coreml.py --checkpoint stories110M.pt --params params.json --output model_cpu.pte --cpu_only

popd

# Build the C++ runner
echo "Building C++ runner..."
BUILD_DIR="${EXECUTORCH_ROOT}/cmake-out"

# Clean build directory completely to avoid stale artifacts and generator conflicts
rm -rf "${BUILD_DIR}"

cmake -S "${EXECUTORCH_ROOT}" -B "${BUILD_DIR}" \
-DCMAKE_INSTALL_PREFIX="${BUILD_DIR}" \
-DCMAKE_BUILD_TYPE=Release \
-DEXECUTORCH_ENABLE_LOGGING=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON \
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \
-DEXECUTORCH_BUILD_COREML=ON \
-G Ninja

cmake --build "${BUILD_DIR}" -j --target run_static_llm_coreml --config Release

# TODO: enable runner once CoreML bug with caching is fixed
# # Run the C++ runner with the CPU model
# echo "Running C++ runner with CPU model..."
# RUNNER="${BUILD_DIR}/examples/apple/coreml/llama/runner/run_static_llm_coreml"
# MODEL_DIR="${EXECUTORCH_ROOT}/examples/apple/coreml/llama"

# # Run the model and capture full output for debugging
# FULL_OUTPUT=$("${RUNNER}" \
# --model "${MODEL_DIR}/model.pte" \
# --params "${MODEL_DIR}/params.json" \
# --tokenizer "${MODEL_DIR}/tokenizer.model" \
# --prompt "Once upon a time," \
# --max_new_tokens 50 2>&1)

# echo "Full output:"
# echo "${FULL_OUTPUT}"

# # Check that the model produced meaningful output
# # The output should contain: the prompt "Once upon a time," and the continuation including "there was"
# # Due to log interleaving, we check for individual key parts separately
# if [[ "${FULL_OUTPUT}" == *"Once upon a time,"* ]] && [[ "${FULL_OUTPUT}" == *"there"* ]] && [[ "${FULL_OUTPUT}" == *"was"* ]]; then
# echo "Output contains expected prompt and generated text"
# echo "C++ runner test passed!"
# else
# echo "ERROR: Output does not contain expected text"
# echo "Expected: 'Once upon a time,' followed by 'there' and 'was'"
# exit 1
# fi

# TODO: enable runner once CoreML bug with caching is fixed
# # Run lookahead decoding test (currently produces <unk> tokens on stories, but works with llama)
# echo "Running C++ runner with lookahead decoding..."
# "${RUNNER}" \
# --model "${MODEL_DIR}/model.pte" \
# --params "${MODEL_DIR}/params.json" \
# --tokenizer "${MODEL_DIR}/tokenizer.model" \
# --prompt "Once upon a time," \
# --max_new_tokens 50 \
# --lookahead

# Test export of deprecated model
pushd $EXECUTORCH_ROOT/examples/apple/coreml/llama
python export.py -n model.pte -p params.json -c stories110M.pt --seq_length 32 --max_seq_length 64 --dtype fp16 --coreml-quantize c4w --embedding-quantize 4,32

popd
4 changes: 2 additions & 2 deletions .github/workflows/trunk.yml
@@ -489,7 +489,7 @@ jobs:
name: test-static-llama-ane
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
with:
runner: macos-m1-stable
runner: macos-15-xlarge
python-version: '3.11'
submodules: 'recursive'
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
@@ -839,7 +839,7 @@ jobs:
qwen3-0.6b|xnnpack|--quantize,
qwen3-1.7b|xnnpack|--quantize,
gemma3-1b|xnnpack|--quantize,
# phi4-mini|xnnpack|--quantize, transformers v5.0.0rc0 introduces a data-dependent branching in transformers/modeling_rope_utils.py:61
# phi4-mini|xnnpack|--quantize, transformers v5.0.0rc0 introduces a data-dependent branching in transformers/modeling_rope_utils.py:61
smollm2-135m|xnnpack|--quantize,
smollm3-3b|xnnpack|--quantize
]
10 changes: 9 additions & 1 deletion CMakeLists.txt
@@ -936,12 +936,20 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
list(APPEND _executorch_extensions extension_training)
endif()

if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
list(APPEND _executorch_extensions extension_llm_runner)
endif()

# Static LLM CoreML runner for Apple platforms
if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER
AND EXECUTORCH_BUILD_COREML
AND APPLE
)
add_subdirectory(
${CMAKE_CURRENT_SOURCE_DIR}/examples/apple/coreml/llama/runner
)
endif()
if(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/asr/runner)
list(APPEND _executorch_extensions extension_asr_runner)
43 changes: 36 additions & 7 deletions examples/apple/coreml/llama/export_static_llm_coreml.py
@@ -98,18 +98,19 @@ def remove_graph_break_(edge_manager):
edge_manager.exported_program().graph_module.graph.eliminate_dead_code()


def load_model(checkpoint_path: str, params_path: str, max_context_len: int):
def load_model(
checkpoint_path: str,
params_path: str,
max_context_len: int,
generate_full_logits: bool = True,
):
"""Load the model from checkpoint with static_mha attention type."""
with open(params_path, "r") as f:
params = json.loads(f.read())

# TODO: to support lookahead decoding, the static model outputs
# full logits, but if we are not using lookahead decoding, we can have a
# more efficient model by setting generate_full_logits=False and supplying the last
# valid token
args = ModelArgs(
max_context_len=max_context_len,
generate_full_logits=True,
generate_full_logits=generate_full_logits,
**params,
)
args.attention_type = "static_mha"
@@ -320,15 +321,39 @@ def main():
help="Disable graph breaks between transformer blocks",
)

# Output options
parser.add_argument(
"--no_generate_full_logits",
action="store_true",
help="Only generate logits for the last token position (more efficient, but no lookahead support).",
)

# Compute options
parser.add_argument(
"--cpu_only",
action="store_true",
help="Use CPU only (no ANE). Useful for CI testing where ANE is not accessible.",
)

args = parser.parse_args()

# Compute cache length

generate_full_logits = not args.no_generate_full_logits

print("Quantization and datatype:")
print(f"\tEmbedding quantize: {args.embedding_quantize}")
print(f"\tLinear quantize: {args.linear_quantize}")
print(f"\tDtype: {args.dtype}")

print("\nOutput configuration:")
print(f"\tGenerate full logits: {generate_full_logits}")
if not generate_full_logits:
print("\t(Lookahead decoding will NOT be supported)")

print("\nCompute configuration:")
print(f"\tCPU only: {args.cpu_only}")

cache_len = args.max_context_len - args.input_len
print("\nGeneration configuration:")
print(f"\tMax context length: {args.max_context_len}")
@@ -345,6 +370,7 @@ def main():
args.checkpoint,
args.params,
args.max_context_len,
generate_full_logits,
)
print(f"Model loaded: {model_args.n_layers} layers, {model_args.dim} dim")

@@ -453,13 +479,16 @@ def main():

# Setup CoreML partitioner
print("\nSetting up CoreML partitioner...")
compute_unit = (
ct.ComputeUnit.CPU_ONLY if args.cpu_only else ct.ComputeUnit.CPU_AND_NE
)
compile_specs = CoreMLBackend.generate_compile_specs(
minimum_deployment_target=ct.target.iOS18,
compute_precision={
torch.float16: ct.precision.FLOAT16,
torch.float32: ct.precision.FLOAT32,
}[float_dtype],
compute_unit=ct.ComputeUnit.CPU_AND_NE,
compute_unit=compute_unit,
model_type=CoreMLBackend.MODEL_TYPE.MODEL,
)
partitioner = CoreMLPartitioner(
59 changes: 58 additions & 1 deletion examples/apple/coreml/llama/readme.md
@@ -29,7 +29,64 @@ The static model has several ANE optimizations, including:
* Re-writing SDPA to avoid 5-D tensors, which improves performance. This also fixes an accuracy bug introduced in iOS 26 (addresses https://github.com/pytorch/executorch/issues/15833)


We are working on adding a C++ runner as well.
## C++ Runner

A C++ runner is also available for running static-attention LLM models. The runner extends `TextDecoderRunner` from the ExecuTorch LLM extension and manages KV cache I/O with smart_mask style cache updates.
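
To make "smart_mask style" concrete, here is a rough, self-contained Python illustration of the idea (not the runner's actual C++ code): the K/V cache tensors keep a fixed length, new entries are written in place at the next free slots, and a validity mask tells attention which slots to use, so nothing is ever shifted between calls. Names and shapes below are assumptions for illustration only.

```python
# Illustrative sketch only; names and shapes are assumptions, not the runner's API.
import torch


class SmartMaskKVCache:
    def __init__(self, cache_len: int, n_heads: int, head_dim: int):
        self.k = torch.zeros(n_heads, cache_len, head_dim)
        self.v = torch.zeros(n_heads, cache_len, head_dim)
        self.valid = torch.zeros(cache_len, dtype=torch.bool)
        self.pos = 0  # next free slot

    def update(self, new_k: torch.Tensor, new_v: torch.Tensor) -> None:
        # new_k / new_v: (n_heads, n_new, head_dim) produced for the latest tokens.
        n_new = new_k.shape[1]
        self.k[:, self.pos : self.pos + n_new] = new_k
        self.v[:, self.pos : self.pos + n_new] = new_v
        self.valid[self.pos : self.pos + n_new] = True
        self.pos += n_new

    def attention_mask(self, input_len: int) -> torch.Tensor:
        # 0.0 for valid cache slots, -inf for empty ones, broadcast over query positions.
        mask = torch.where(self.valid, 0.0, float("-inf"))
        return mask.unsqueeze(0).expand(input_len, -1)
```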

### Building on macOS

The easiest way to build is using the provided build script:

```bash
cd examples/apple/coreml/llama/runner
./build_and_run.sh --help # Show options
./build_and_run.sh # Build and run with defaults
```

Or build manually from the ExecuTorch root directory using the `macos` preset:

```bash
cmake -S . -B cmake-out --preset macos
cmake --build cmake-out --config Release --target run_static_llm_coreml -j$(sysctl -n hw.ncpu)
```

The executable will be at: `cmake-out/examples/apple/coreml/llama/runner/Release/run_static_llm_coreml`

### Running

```bash
./cmake-out/examples/apple/coreml/llama/runner/Release/run_static_llm_coreml \
--model static_llm_coreml_model.pte \
--params /path/to/params.json \
--tokenizer /path/to/tokenizer.model \
--prompt "Once upon a time," \
--max_new_tokens 100 \
--input_len 32 \
--cache_len 992 \
--temperature 0.0
```

### Command-line Options

| Option | Description | Default |
|--------|-------------|---------|
| `--model` | Path to the .pte model file | (required) |
| `--params` | Path to params.json | (required) |
| `--tokenizer` | Path to tokenizer file | (required) |
| `--prompt` | Input prompt | (required) |
| `--max_new_tokens` | Maximum tokens to generate | 100 |
| `--input_len` | Input sequence length (must match export; see note below) | 32 |
| `--cache_len` | KV cache length (must match export; see note below) | 992 |
| `--temperature` | Sampling temperature (0 = greedy) | 0.0 |
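
Note that the `--input_len` and `--cache_len` defaults follow the export-time relationship `cache_len = max_context_len - input_len` used by `export_static_llm_coreml.py`. A minimal sanity check, assuming the model was exported with a max context length of 1024:

```python
# Hypothetical consistency check for the values passed to the runner.
# cache_len must equal max_context_len - input_len used at export time.
max_context_len = 1024  # assumed export-time --max_context_len
input_len = 32          # runner --input_len (must match export)
cache_len = max_context_len - input_len
assert cache_len == 992  # runner --cache_len default
```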

### Features

The C++ runner:
- Extends `TextDecoderRunner` from `executorch/extension/llm/runner/`
- Manages KV cache I/O with smart_mask style cache updates
- Supports multiple tokenizer formats (HuggingFace JSON, TikToken, SentencePiece, BPE)
- Computes RoPE frequencies internally (Llama 3 style with base=500000; see the sketch after this list)
- Reads model configuration from params.json
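
The RoPE computation mentioned above follows the standard rotary-embedding formula with base 500000; a minimal Python sketch, where `head_dim` and `max_pos` are chosen purely for illustration:

```python
# Sketch of Llama-3-style RoPE frequency tables with base 500000.
# head_dim and max_pos are illustrative assumptions, not the runner's values.
import torch


def rope_tables(head_dim: int = 64, max_pos: int = 1024, base: float = 500_000.0):
    # Inverse frequencies for each pair of dimensions in a head.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(max_pos).float()
    angles = torch.outer(positions, inv_freq)  # (max_pos, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)


cos, sin = rope_tables()
```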


# Deprecated (export.py, run.py, and run_lookahead.py)