Draft
25 commits
efe81ee
LlmModule prefill refactor
kirklandsign Sep 9, 2025
015a6ab
Doing some rename
kirklandsign Sep 9, 2025
16b6d1c
Java layer no longer need a separate generateFromPos
kirklandsign Sep 9, 2025
e08df47
Remove generateFromPos API
kirklandsign Sep 9, 2025
e465fa2
Add audio input type
kirklandsign Sep 9, 2025
d62be5a
make private method now
kirklandsign Sep 10, 2025
3a4ffce
Merge remote-tracking branch 'origin/main' into start-pos-api-llava-7-9
kirklandsign Sep 22, 2025
392c157
Use prefill API
kirklandsign Sep 22, 2025
beb1784
Add a prefill() method for text llm runner
kirklandsign Sep 22, 2025
80c6378
Merge branch 'android-use-prefill-api' into start-pos-api-llava-7-9
kirklandsign Sep 22, 2025
1cd3582
Android use new prefill API
kirklandsign Sep 22, 2025
a271b07
Merge remote-tracking branch 'origin/main' into start-pos-api-llava-7-9
kirklandsign Sep 23, 2025
148bd91
QNN override
kirklandsign Sep 23, 2025
0a19212
fix
kirklandsign Sep 23, 2025
659007d
fix
kirklandsign Sep 23, 2025
f4de63a
fix
kirklandsign Sep 23, 2025
8745801
fix qnn compile
kirklandsign Sep 23, 2025
05378f5
Merge remote-tracking branch 'origin/main' into start-pos-api-llava-7-9
kirklandsign Sep 23, 2025
d020306
strange things
kirklandsign Sep 24, 2025
f8b1c47
let me have a try
kirklandsign Sep 24, 2025
5d9a9b9
i
kirklandsign Sep 24, 2025
f307c60
Revert "XNNPACK: Kleidi QP8 and SME2 (#13887)"
kirklandsign Sep 24, 2025
89fdd72
Update
kirklandsign Sep 29, 2025
c0cccb6
Merge remote-tracking branch 'origin/main' into audio-input-test
kirklandsign Sep 29, 2025
d58174a
Test
kirklandsign Oct 10, 2025
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1017,7 +1017,7 @@ install(TARGETS executorch_backends executorch_extensions executorch_kernels

if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
# Baseline libraries that executor_runner will link against.
set(_executor_runner_libs executorch extension_evalue_util
set(_executor_runner_libs executorch extension_evalue_util xnnpack_backend
extension_runner_util gflags executorch_backends
)

3 changes: 2 additions & 1 deletion backends/xnnpack/cmake/Dependencies.cmake
@@ -42,8 +42,9 @@ set(XNNPACK_ENABLE_AVX512VNNIGFNI
OFF
CACHE BOOL ""
)

set(XNNPACK_ENABLE_ARM_SME2
ON
OFF
CACHE BOOL ""
)
if(EXECUTORCH_XNNPACK_ENABLE_KLEIDI)
16 changes: 6 additions & 10 deletions backends/xnnpack/runtime/XNNCompiler.cpp
@@ -610,6 +610,9 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
auto cvt_output_id = graph_node->output_id();

auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
assert(
dtype == DataType::xnn_datatype_qdint8 ||
dtype == DataType::xnn_datatype_qbint4);
for (auto value : *graph->xvalues()) {
if (value->xvalue_union_type() !=
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -629,23 +632,16 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
return false;
}

// XNNPACK dtypes which have qp8 support.
const std::vector<DataType> supported_filter_dtypes = {
DataType::xnn_datatype_qbint4,
DataType::xnn_datatype_qcint4,
DataType::xnn_datatype_qcint8};

// Find if the convert output is going to the right linear node.
// Assuming if we can find one valid linear node, then we can use QP8
// for all the linear nodes consuming this convert output.
for (auto node : *graph->xnodes()) {
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
auto linear_node = node->xnode_union_as_XNNFullyConnected();
if (linear_node->input1_id() == cvt_output_id) {
for (auto supported_filter_dtype : supported_filter_dtypes) {
if (check_dtype(linear_node->filter_id(), supported_filter_dtype)) {
return true;
}
if (check_dtype(
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
return true;
}
}
}
7 changes: 7 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -459,6 +459,13 @@ Error Runner<T>::generate_from_prompt_or_file(
return Error::Ok;
}

template <typename T>
::executorch::runtime::Error Runner<T>::prefill(
const std::string& prompt,
const executorch::extension::llm::GenerationConfig& config) {
return ::executorch::runtime::Error::NotImplemented;
}

template <typename T>
Result<DecoderModelVersion> Runner<T>::get_decoder_model_version() {
if (!is_loaded()) {
4 changes: 4 additions & 0 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -79,6 +79,10 @@ class Runner : public executorch::extension::llm::IRunner {
const executorch::extension::llm::GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const executorch::llm::Stats&)> stats_callback = {});

executorch::runtime::Error prefill(
const std::string& prompt,
const executorch::extension::llm::GenerationConfig& config = {}) override;
void stop() override {};
void reset() override {};
executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
LlmModule.java
@@ -167,7 +167,7 @@ public int generate(
}

/**
* Prefill an LLaVA Module with the given images input.
* Prefill a multimodal Module with the given images input.
*
* @param image Input image as a byte array
* @param width Input image width
@@ -189,9 +189,9 @@ public long prefillImages(int[] image, int width, int height, int channels) {
private native int appendImagesInput(int[] image, int width, int height, int channels);

/**
* Prefill an LLaVA Module with the given text input.
* Prefill a multimodal Module with the given text input.
*
* @param prompt The text prompt to LLaVA.
* @param prompt The text prompt to the multimodal model.
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
* exposed to user.
* @throws RuntimeException if the prefill failed
@@ -208,6 +208,35 @@ public long prefillPrompt(String prompt) {
// returns status
private native int appendTextInput(String prompt);

/**
* Prefill a multimodal Module with audio features read from a preprocessed audio file.
*
* @param filePath Path to a file containing the preprocessed audio features as raw float32 values.
* @return The status code returned by the native audio prefill call (0 on success).
* @throws RuntimeException if the file cannot be read completely
*/
public int prefillAudio(String filePath) {
java.io.File file = new java.io.File(filePath);
try (java.io.FileInputStream fis = new java.io.FileInputStream(file)) {
byte[] fileBytes = new byte[(int) file.length()];
int bytesRead = fis.read(fileBytes);
if (bytesRead != fileBytes.length) {
throw new RuntimeException("Could not completely read file " + file.getName());
}
int nFloats = fileBytes.length / 4;
int batchSize = nFloats / (128 * 3000);
return appendAudioInput(fileBytes, batchSize, 128, 3000);
} catch (java.io.IOException e) {
throw new RuntimeException("Failed to read file: " + e);
}
}

// For preprocessed Audio input (option B), not RawAudio.
// Use batch_size = ceil(n_floats / (n_bins * n_frames)), with n_bins = 128 and n_frames = 3000.
// Returns the status code.
private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);

/**
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
*
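Taken together, the new Java entry points can be driven roughly as follows. This is a hedged usage sketch, not code from this PR: it assumes an already loaded multimodal LlmModule instance, a preprocessed float32 feature file on device, and placeholder chat-template strings. Note that in the current draft the JNI audio path still hard-codes its own prompt wrapping and on-device file path (visible in jni_layer_llama.cpp below), so this reflects the intended API shape rather than the current test behavior.

// Hedged usage sketch (illustrative only; not part of this PR).
// Assumes `module` is an already loaded multimodal LlmModule and that a
// preprocessed float32 feature file (128 bins x 3000 frames per batch, as
// produced by the Voxtral preprocessor) exists at the given path.
void prefillConversation(LlmModule module) {
  // Reload earlier chat history into the KV cache; position is tracked internally.
  module.prefillPrompt("<s>[INST]Earlier user turn[/INST]Earlier model reply</s>");
  // Prefill the audio features; the returned value is the native status code.
  int status = module.prefillAudio("/data/local/tmp/llama/audio.bin");
  if (status != 0) {
    throw new RuntimeException("Audio prefill failed with status " + status);
  }
  // A later generate() call continues from the internally tracked KV-cache position.
}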
164 changes: 152 additions & 12 deletions extension/android/jni/jni_layer_llama.cpp
@@ -12,6 +12,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <fstream>

#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/irunner.h>
@@ -41,6 +42,7 @@

namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;
using executorch::extension::Module;

namespace {
bool utf8_check_validity(const char* str, size_t length) {
@@ -123,7 +125,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
std::unique_ptr<llm::IRunner> runner_;
std::unique_ptr<executorch::extension::llm::MultimodalRunner>
multi_modal_runner_;
std::vector<llm::MultimodalInput> prefill_inputs_;

public:
constexpr static auto kJavaDescriptor =
@@ -213,8 +214,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
prefill_inputs_.clear();
std::vector<llm::MultimodalInput> inputs;
if (!prompt->toStdString().empty()) {
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
}
@@ -223,6 +223,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
.seq_len = seq_len,
.temperature = temperature_,
};
ET_LOG(Error, "Generating with multimodal runner %s", prompt->toStdString().c_str());
multi_modal_runner_->generate(
std::move(inputs),
config,
@@ -245,17 +246,28 @@

// Returns status_code
// Contract is valid within an AAR (JNI + corresponding Java code)
jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
return 0;
jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
runner_->prefill(prompt->toStdString(), {});
return 0;
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_->prefill(
{llm::MultimodalInput{prompt->toStdString()}});
return 0;
    }
    return static_cast<jint>(Error::InvalidArgument);
  }

// Returns status_code
jint append_images_input(
jint prefill_images_input(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::InvalidArgument);
}
if (image == nullptr) {
return static_cast<jint>(Error::InvalidArgument);
}
std::vector<llm::Image> images;
if (image == nullptr) {
return static_cast<jint>(Error::EndOfMethod);
@@ -269,10 +281,136 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{std::move(image_data), width, height, channels};
prefill_inputs_.emplace_back(
llm::MultimodalInput{std::move(image_runner)});
multi_modal_runner_->prefill(
{llm::MultimodalInput{std::move(image_runner)}});
}

return 0;
}

llm::MultimodalInput processRawAudioFile(
const std::string& audio_path,
const std::string& processor_path) {
if (processor_path.empty()) {
ET_LOG(Error, "Processor path is required for raw audio processing");
throw std::runtime_error(
"Processor path is required for raw audio processing");
}

// Load the audio processor .pte.
std::unique_ptr<Module> processor_module;
try {
processor_module =
std::make_unique<Module>(processor_path, Module::LoadMode::File);
auto load_error = processor_module->load();
if (load_error != ::executorch::runtime::Error::Ok) {
ET_LOG(
Error,
"Failed to load processor module from: %s",
processor_path.c_str());
throw std::runtime_error("Failed to load processor module");
}
} catch (const std::exception& e) {
ET_LOG(Error, "Exception while loading processor module: %s", e.what());
throw std::runtime_error("Exception while loading processor module");
}

// Load the audio data from file.
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
if (!f.is_open()) {
ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str());
throw std::runtime_error("Failed to open audio file");
}

std::size_t n_floats = f.tellg() / sizeof(float);
f.seekg(0, std::ios::beg);

std::vector<float> audio_data(n_floats);
f.read(
reinterpret_cast<char*>(audio_data.data()),
audio_data.size() * sizeof(float));
f.close();

ET_LOG(
Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats);

// Execute the processor
std::vector<executorch::aten::SizesType> tensor_shape = {
static_cast<executorch::aten::SizesType>(audio_data.size())};
auto input_tensor = executorch::extension::from_blob(
audio_data.data(), tensor_shape, ::executorch::aten::ScalarType::Float);

ET_LOG(Info, "Processing audio through processor module...");
auto result = processor_module->execute("forward", input_tensor);
if (!result.ok()) {
ET_LOG(Error, "Failed to execute processor's forward method");
throw std::runtime_error("Failed to execute processor forward method");
}

auto outputs = result.get();
if (outputs.empty()) {
ET_LOG(Error, "Processor returned no outputs");
throw std::runtime_error("Processor returned no outputs");
}

// Extract processed audio features
const auto& processed_tensor = outputs[0].toTensor();
const float* processed_data = processed_tensor.const_data_ptr<float>();
const auto& sizes = processed_tensor.sizes();

ET_LOG(
Info,
"Processed audio tensor shape: [%d, %d, %d]",
static_cast<int>(sizes[0]),
static_cast<int>(sizes[1]),
static_cast<int>(sizes[2]));

// Create Audio multimodal input from processed features
int32_t batch_size = static_cast<int32_t>(sizes[0]);
int32_t n_bins = static_cast<int32_t>(sizes[1]);
int32_t n_frames = static_cast<int32_t>(sizes[2]);
size_t total_elements = batch_size * n_bins * n_frames;
std::vector<float> audio_vec(processed_data, processed_data + total_elements);
auto processed_audio = ::executorch::extension::llm::Audio(
std::move(audio_vec), batch_size, n_bins, n_frames);
ET_LOG(
Info,
"Created processed Audio: batch_size=%d, n_bins=%d, n_frames=%d",
batch_size,
n_bins,
n_frames);
return ::executorch::extension::llm::make_audio_input(
std::move(processed_audio));
}

jint prefill_audio_input(
facebook::jni::alias_ref<jbyteArray> audio,
jint batch_size,
jint n_bins,
jint n_frames) {
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(Error::InvalidArgument);
}
if (audio == nullptr) {
return static_cast<jint>(Error::InvalidArgument);
}
// auto audio_size = audio->size();
// std::vector<uint8_t> audio_data(audio_size);
// if (audio_size != 0) {
// std::vector<jbyte> audio_data_jbyte(audio_size);
// audio->getRegion(0, audio_size, audio_data_jbyte.data());
// for (int i = 0; i < audio_size; i++) {
// audio_data[i] = audio_data_jbyte[i];
// }
// llm::Audio audio_input{std::move(audio_data), batch_size, n_bins, n_frames};

multi_modal_runner_->prefill(
{executorch::extension::llm::make_text_input("<s>[INST][BEGIN_AUDIO]"),
processRawAudioFile("/data/local/tmp/llama/audio.bin", "/data/local/tmp/llama/voxtral_preprocessor.pte"),
executorch::extension::llm::make_text_input(std::string("What can you tell me about this audio ") + "[/INST]")});
// }

ET_LOG(Error, "PREFILL AUDIO INPUT GOOD!!!!!!!!!!");
return 0;
}

@@ -309,9 +447,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
"appendImagesInput", ExecuTorchLlmJni::prefill_images_input),
makeNativeMethod(
"appendTextInput", ExecuTorchLlmJni::prefill_text_input),
makeNativeMethod(
"appendTextInput", ExecuTorchLlmJni::append_text_input),
"appendAudioInput", ExecuTorchLlmJni::prefill_audio_input),
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
});
}
11 changes: 11 additions & 0 deletions extension/llm/runner/irunner.h
@@ -125,6 +125,17 @@ class ET_EXPERIMENTAL IRunner {
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;

/**
* Prefill text inputs, for example to reload chat history.
* @param prompt Text prompt to prefill.
* @param config Configuration parameters (num_bos and num_eos are applied if non-zero).
* @return The error code. KV cache position is tracked internally in pos_.
*/
virtual ::executorch::runtime::Error prefill(
const std::string& prompt,
const GenerationConfig& config = {}) = 0;

/**
* Stop the generation process.
*/
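For reference, a hedged sketch of how a caller might drive the new prefill() API to restore chat history before resuming generation. This is illustrative only and not part of the PR; it assumes a loaded concrete IRunner implementation, uses only the IRunner methods visible in this diff plus the existing generate() entry point, and the seq_len field of GenerationConfig (also set this way in the JNI layer above). The 1024 cap is a hypothetical value.

// Hedged caller-side sketch (not part of this PR): replay prior turns through
// prefill(), then generate the next response. The KV-cache position is tracked
// internally by the runner, so no start_pos needs to be passed around.
#include <executorch/extension/llm/runner/irunner.h>

#include <functional>
#include <string>
#include <vector>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::IRunner;
using executorch::extension::llm::Stats;
using executorch::runtime::Error;

Error resume_chat(
    IRunner& runner,
    const std::vector<std::string>& history,
    const std::string& new_prompt) {
  GenerationConfig config;
  config.seq_len = 1024;  // hypothetical context cap for this sketch
  // Prefill each earlier turn; the runner advances its internal position.
  for (const auto& turn : history) {
    Error err = runner.prefill(turn, config);
    if (err != Error::Ok) {
      return err;
    }
  }
  // Generation continues from the internally tracked position.
  return runner.generate(
      new_prompt,
      config,
      [](const std::string& token) { /* stream token to the UI */ },
      [](const Stats& stats) { /* report generation stats */ });
}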