
Commit 78e57fd ("Update")

Parent: 82611e9

File tree: 9 files changed, +892 −26 lines

backends/cuda/cuda_backend.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -154,6 +154,8 @@ def preprocess(
             "aot_inductor.package_constants_in_so": False,
             # Store weight constants on disk in a binary blob
             "aot_inductor.package_constants_on_disk_format": "binary_blob",
+            # Avoid issues like 'NoneType' object has no attribute 'reorder_iter_loops'
+            "loop_ordering_after_fusion": False,
             # Enable maximum automatic tuning for optimal performance
             "max_autotune": True,
             # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
```
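
For context, `loop_ordering_after_fusion` is a standard `torch._inductor` config flag, so the effect of the new entry can be reproduced outside this backend. Below is a minimal, hypothetical sketch (not part of this commit) that feeds the same options to AOTInductor via `torch._inductor.aoti_compile_and_package`; the toy model and package path are illustrative only.

```python
# Hypothetical sketch: exercising the same AOTInductor options directly.
# Assumes a recent PyTorch; TinyModel and "tiny.pt2" are illustrative.
import torch
from torch._inductor import aoti_compile_and_package


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x)


ep = torch.export.export(TinyModel(), (torch.randn(8, 8),))
aoti_compile_and_package(
    ep,
    package_path="tiny.pt2",
    inductor_configs={
        # Skip the loop-reordering pass that can trigger the
        # "'NoneType' object has no attribute 'reorder_iter_loops'" failure.
        "loop_ordering_after_fusion": False,
        "max_autotune": True,
    },
)
```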

backends/cuda/cuda_partitioner.py

Lines changed: 13 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@
     PartitionResult,
 )
 from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
+from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
 from torch.export.exported_program import ExportedProgram
 
 
@@ -56,6 +57,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         tag_constant_data(exported_program)
         tag_mutated_buffer(exported_program)
 
+        # Tag constant placeholders that have no users.
+        # tag_constant_data only tags constants that have users with delegation_tag,
+        # but we need to tag all constants for this partition.
+        for node in exported_program.graph.nodes:
+            if node.op == "placeholder" and (
+                is_param(exported_program, node)
+                or is_buffer(exported_program, node)
+                or is_lifted_tensor_constant(exported_program, node)
+            ):
+                if "delegation_tag" not in node.meta:
+                    node.meta["delegation_tag"] = tag
+
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )
```
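
For illustration, here is a self-contained sketch (not from this commit) of the same placeholder-tagging logic applied to a freshly exported toy program; the module, the `tag0` name, and the final printout are stand-ins.

```python
# Standalone sketch of the constant-placeholder tagging pass above.
# WithConstants and "tag0" are illustrative stand-ins.
import torch
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param


class WithConstants(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))
        self.register_buffer("offset", torch.ones(4))

    def forward(self, x):
        return x @ self.weight + self.offset


ep = torch.export.export(WithConstants(), (torch.randn(2, 4),))
tag = "tag0"

# Parameters, buffers, and lifted constants appear as extra placeholders
# in the exported graph; tag any that are not already tagged.
for node in ep.graph.nodes:
    if node.op == "placeholder" and (
        is_param(ep, node)
        or is_buffer(ep, node)
        or is_lifted_tensor_constant(ep, node)
    ):
        node.meta.setdefault("delegation_tag", tag)

print([n.name for n in ep.graph.nodes if n.meta.get("delegation_tag") == tag])
```
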
examples/models/whisper/CMakeLists.txt

Lines changed: 101 additions & 0 deletions

```cmake
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.24)
project(whisper_runner)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../..")
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)

# Let files say "include <executorch/path/to/header.h>"
set(_common_include_directories ${EXECUTORCH_ROOT}/..)

# Need this for gflags for some reason
set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags)
find_package(gflags REQUIRED)

list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..)
find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH)
executorch_target_link_options_shared_lib(executorch)

set(link_libraries executorch gflags)
set(_srcs multimodal.cpp)

list(
  APPEND
  link_libraries
  optimized_native_cpu_ops_lib
  quantized_ops_lib
  custom_ops
  cpublas
  eigen_blas
)
executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib)
executorch_target_link_options_shared_lib(quantized_ops_lib)
executorch_target_link_options_shared_lib(custom_ops)

# XNNPACK
if(TARGET xnnpack_backend)
  set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod)
  if(TARGET kleidiai)
    list(APPEND xnnpack_backend_libs kleidiai)
  endif()
  list(APPEND link_libraries ${xnnpack_backend_libs})
  executorch_target_link_options_shared_lib(xnnpack_backend)
endif()

# Add LLM runner and extension module
if(NOT TARGET extension_llm_runner)
  message(
    FATAL_ERROR
      "ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER enabled."
  )
endif()

# Needed for cpuinfo, which uses the Android-specific log library
if(ANDROID)
  list(APPEND link_libraries log)
endif()

# Add the required ExecuTorch extensions for the multimodal LLM runner
list(
  APPEND
  link_libraries
  extension_llm_runner
  extension_module
  extension_data_loader
  extension_tensor
  extension_flat_tensor
)

# Link CUDA backend
if(EXECUTORCH_BUILD_CUDA)
  find_package(CUDAToolkit REQUIRED)
  list(APPEND link_libraries aoti_cuda)
  executorch_target_link_options_shared_lib(aoti_cuda)
endif()

if(EXECUTORCH_BUILD_METAL)
  list(APPEND link_libraries metal_backend)
  executorch_target_link_options_shared_lib(metal_backend)
endif()

# Add tokenizers
list(APPEND link_libraries tokenizers::tokenizers)

add_executable(whisper_runner runner.cpp main.cpp)

target_include_directories(whisper_runner PUBLIC ${_common_include_directories})

target_link_libraries(
  whisper_runner
  PUBLIC
  ${link_libraries}
)
target_compile_options(whisper_runner PUBLIC ${_common_compile_options})
```

examples/models/whisper/README.md

Lines changed: 69 additions & 0 deletions

# Whisper Runner

This directory hosts a lightweight C++ helper that drives Whisper models
exported to ExecuTorch. The `WhisperRunner` owns the `Module` instance that
wraps a bundled `.pte` program and an optional `.ptd` weight file, loads the
`encoder` and `text_decoder` methods, and exposes a `transcribe()` loop that
streams decoded text pieces through a callback.

The runner assumes:
- `model.pte` contains both the Whisper encoder and decoder entry points,
  named `encoder` and `text_decoder`.
- External parameters (for example, KV cache blocks) are stored in a
  companion `model.ptd`.
- A tokenizer JSON compatible with the ExecuTorch tokenizers shim is available.

Audio preprocessing is not part of the runner itself. To transform raw audio
into the mel features expected by the encoder, reuse the pattern in
`examples/models/voxtral/multimodal.cpp`, which loads a `preprocessor.pte`
module to generate the spectrogram tensor.
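
The export side is out of scope for this runner, but a multi-method `.pte` of the shape it expects can be sketched with `executorch.exir.to_edge`. The following is rough and illustrative only: toy modules stand in for the real Whisper encoder and decoder, and the flow that moves external weights into `model.ptd` is omitted.

```python
# Rough sketch: packaging "encoder" and "text_decoder" entry points into a
# single .pte. Toy modules and shapes stand in for the real Whisper stacks.
import torch
from executorch.exir import to_edge


class ToyEncoder(torch.nn.Module):
    def forward(self, mel):
        return mel.mean(dim=-1)  # stand-in for the encoder stack


class ToyDecoder(torch.nn.Module):
    def forward(self, tokens, enc_out):
        # Stand-in for the text decoder: one "logit" row per token position.
        return tokens.float().unsqueeze(-1) + enc_out.unsqueeze(1)


mel = torch.randn(1, 80, 3000)  # illustrative mel-spectrogram shape
enc_ep = torch.export.export(ToyEncoder(), (mel,))
dec_ep = torch.export.export(
    ToyDecoder(), (torch.zeros(1, 4, dtype=torch.long), torch.randn(1, 80))
)

et_program = to_edge({"encoder": enc_ep, "text_decoder": dec_ep}).to_executorch()
with open("model.pte", "wb") as f:
    f.write(et_program.buffer)
```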

## Build

```bash
cmake -G Ninja \
  -B cmake-out/examples/models/whisper \
  -S examples/models/whisper
cmake --build cmake-out/examples/models/whisper -j
```

The build produces the `whisper_runner` executable, linked against the
standard ExecuTorch runtime libraries and the tokenizer target
(`tokenizers::tokenizers`).

## Usage

```cpp
#include <iostream>
#include <string>

#include <executorch/examples/models/whisper/runner.h>
#include <executorch/extension/tensor/tensor_ptr.h>

using example::WhisperRunner;
using example::WhisperTranscribeConfig;

WhisperRunner runner("model.pte", "model.ptd", "tokenizer.json");
ET_CHECK_OK(runner.load());

// `features` is the mel spectrogram tensor produced by the preprocessor.
executorch::aten::Tensor features = load_features_somehow();

WhisperTranscribeConfig config;
config.max_new_tokens = 128;  // stop after 128 generated tokens
config.temperature = 0.7f;    // optional: enable stochastic sampling

auto tokens_result = runner.transcribe(
    features,
    config,
    [](const std::string& piece) {
      std::cout << piece;
    });

if (!tokens_result.ok()) {
  ET_LOG(Error, "Transcription failed: %d", static_cast<int>(tokens_result.error()));
}
```

`transcribe()` returns the full token history (prompt + generated tokens) and
invokes the callback every time a new token is emitted. Provide a non-empty
`decoder_input_ids` vector if you want to seed the decoder with a custom
prompt, and override `WhisperTranscribeConfig::eos_token_ids` when the model
exposes custom termination ids.
