diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index fc4ff006a90..14df30779aa 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -464,6 +464,13 @@ Error Runner<T>::generate_from_prompt_or_file(
   return Error::Ok;
 }
 
+template <typename T>
+::executorch::runtime::Error Runner<T>::prefill(
+    const std::string& prompt,
+    const executorch::extension::llm::GenerationConfig& config) {
+  return Error::NotImplemented;
+}
+
 template <typename T>
 Result<DecoderModelVersion> Runner<T>::get_decoder_model_version() {
   if (!is_loaded()) {
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 9f290d79c75..41d1ae19bdc 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -79,6 +79,10 @@ class Runner : public executorch::extension::llm::IRunner {
       const executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const executorch::extension::llm::Stats&)> stats_callback = {});
+
+  executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config = {}) override;
   void stop() override {};
   void reset() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index cabf30c42e4..68b96cb64ff 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -123,7 +123,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
-  std::vector<llm::MultimodalInput> prefill_inputs_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -213,8 +212,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
+      std::vector<llm::MultimodalInput> inputs;
       if (!prompt->toStdString().empty()) {
         inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       }
@@ -245,17 +243,26 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
   // Returns status_code
   // Contract is valid within an AAR (JNI + corresponding Java code)
-  jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
-    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-    return 0;
+  jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      runner_->prefill(prompt->toStdString(), {});
+      return 0;
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{prompt->toStdString()}});
+      return 0;
+    }
+    return static_cast<jint>(Error::InvalidArgument);
   }
 
-  // Returns status_code
-  jint append_images_input(
+  jint prefill_images_input(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
       jint channels) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::InvalidArgument);
+    }
     std::vector<llm::Image> images;
     if (image == nullptr) {
       return static_cast<jint>(Error::EndOfMethod);
@@ -269,13 +276,39 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         image_data[i] = image_data_jint[i];
       }
       llm::Image image_runner{std::move(image_data), width, height, channels};
-      prefill_inputs_.emplace_back(
-          llm::MultimodalInput{std::move(image_runner)});
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{std::move(image_runner)}});
     }
     return 0;
   }
 
+  jint prefill_audio_input(
+      facebook::jni::alias_ref<jintArray> audio,
+      jint batch_size,
+      jint n_channels,
+      jint n_samples) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(Error::InvalidArgument);
+    }
+    if (audio == nullptr) {
+      return static_cast<jint>(Error::InvalidArgument);
+    }
+    auto audio_size = audio->size();
+    std::vector<uint8_t> audio_data(audio_size);
+    if (audio_size != 0) {
+      std::vector<jint> audio_data_jint(audio_size);
+      audio->getRegion(0, audio_size, audio_data_jint.data());
+      for (int i = 0; i < audio_size; i++) {
+        audio_data[i] = audio_data_jint[i];
+      }
+      llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{std::move(audio_input)}});
+    }
+    return 0;
+  }
+
 
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -309,9 +342,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         makeNativeMethod("stop", ExecuTorchLlmJni::stop),
         makeNativeMethod("load", ExecuTorchLlmJni::load),
         makeNativeMethod(
-            "appendImagesInput", ExecuTorchLlmJni::append_images_input),
+            "appendImagesInput", ExecuTorchLlmJni::prefill_images_input),
+        makeNativeMethod(
+            "appendTextInput", ExecuTorchLlmJni::prefill_text_input),
         makeNativeMethod(
-            "appendTextInput", ExecuTorchLlmJni::append_text_input),
+            "appendAudioInput", ExecuTorchLlmJni::prefill_audio_input),
         makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
       });
   }
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index ef93f32319c..6699234be4d 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -125,6 +125,17 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;
 
+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters; num_bos and num_eos are applied
+   *     if non-zero.
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  virtual ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config = {}) = 0;
+
   /**
    * Stop the generation process.
    */
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 333716ac831..ec9c6c5242f 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -217,6 +217,28 @@ Error TextLLMRunner::generate(
   return Error::Ok;
 }
 
+Error TextLLMRunner::prefill(
+    const std::string& prompt,
+    const GenerationConfig& config) {
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+      prompt,
+      /*bos=*/config.num_bos,
+      /*eos=*/config.num_eos);
+
+  ET_CHECK_TK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // Retrieve the encoded prompt as a token sequence and prefill it.
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  return Error::Ok;
+}
+
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup
   GenerationConfig config{
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index 9dd99d82d59..865b8a3bd53 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -101,6 +101,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;
 
+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters; num_bos and num_eos are applied
+   *     if non-zero.
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config = {}) override;
+
   /**
    * @brief Warms up the model with a sample prompt
    *
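
Usage sketch (not part of the patch): the helper below shows how the prefill API added above could be driven from C++ to replay chat history before generating. It is a minimal sketch against the IRunner interface in this diff only; resume_chat, the history loop, and the 128-token budget are illustrative assumptions, not code from this change.

#include <cstdio>
#include <functional>
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/irunner.h>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::IRunner;
using executorch::runtime::Error;

// Hypothetical helper: replay prior chat turns into the KV cache via
// prefill(), then generate a reply only for the newest message. prefill()
// advances the internally tracked cache position (pos_), so generate()
// continues from where the replayed history left off.
Error resume_chat(
    IRunner& runner,
    const std::vector<std::string>& history,
    const std::string& new_message) {
  for (const std::string& turn : history) {
    // Default-constructed GenerationConfig: num_bos/num_eos stay zero, so
    // no extra special tokens are inserted between replayed turns.
    Error err = runner.prefill(turn, {});
    if (err != Error::Ok) {
      return err;
    }
  }
  GenerationConfig config;
  config.max_new_tokens = 128; // illustrative budget
  return runner.generate(
      new_message,
      config,
      /*token_callback=*/
      [](const std::string& token) { std::fputs(token.c_str(), stdout); },
      /*stats_callback=*/{});
}

Prefilling turn by turn keeps pos_ monotonic, so the later generate() call only pays decode-time cost for the new message rather than re-encoding the whole conversation.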