diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp index aab5bfb4720..24809f12144 100644 --- a/examples/models/llava/runner/llava_runner.cpp +++ b/examples/models/llava/runner/llava_runner.cpp @@ -15,9 +15,7 @@ #include #include -#include #include -#include #include namespace llm = ::executorch::extension::llm; @@ -49,7 +47,8 @@ Error LlavaRunner::load() { // Load the text decoder runner text_decoder_runner_ = // @lint-ignore CLANGTIDY facebook-hte-Deprecated - std::make_unique(module_.get()); + std::make_unique( + module_.get(), io_manager_.get()); // @lint-ignore CLANGTIDY facebook-hte-Deprecated text_decoder_runner_->load(); diff --git a/examples/models/llava/runner/llava_text_decoder_runner.h b/examples/models/llava/runner/llava_text_decoder_runner.h index a5ad6fcab0a..09b8e82d49d 100644 --- a/examples/models/llava/runner/llava_text_decoder_runner.h +++ b/examples/models/llava/runner/llava_text_decoder_runner.h @@ -18,8 +18,10 @@ namespace example { class ET_EXPERIMENTAL LlavaTextDecoderRunner : public executorch::extension::llm::TextDecoderRunner { public: - explicit LlavaTextDecoderRunner(executorch::extension::Module* module) - : TextDecoderRunner(module) {} + explicit LlavaTextDecoderRunner( + executorch::extension::Module* module, + executorch::extension::llm::IOManager* io_manager) + : TextDecoderRunner(module, io_manager) {} inline executorch::runtime::Result step( executorch::extension::TensorPtr& tokens, diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h index c17e039c11b..57ad2fd35d9 100644 --- a/extension/llm/runner/multimodal_runner.h +++ b/extension/llm/runner/multimodal_runner.h @@ -16,11 +16,10 @@ #include #include #include -#include -#include #include #include +#include #include #include #include @@ -41,6 +40,7 @@ class ET_EXPERIMENTAL MultimodalRunner { const float temperature = 0.8f) : temperature_(temperature), module_(std::make_unique(model_path, Module::LoadMode::File)), + io_manager_(std::make_unique()), tokenizer_path_(tokenizer_path) { ET_LOG( Info, @@ -127,6 +127,7 @@ class ET_EXPERIMENTAL MultimodalRunner { std::unique_ptr text_decoder_runner_; std::unique_ptr text_prefiller_; std::unique_ptr image_prefiller_; + std::unique_ptr io_manager_; std::unique_ptr text_token_generator_; std::string tokenizer_path_; std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index b6434d3e51d..c1d7ef48b17 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -36,6 +36,7 @@ def define_common_targets(): ":stats", "//executorch/kernels/portable/cpu/util:arange_util" + aten_suffix, "//executorch/extension/llm/sampler:sampler" + aten_suffix, + "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, "//executorch/extension/module:module" + aten_suffix, "//executorch/extension/tensor:tensor" + aten_suffix, ], @@ -102,6 +103,7 @@ def define_common_targets(): ":text_decoder_runner" + aten_suffix, ":text_prefiller" + aten_suffix, ":text_token_generator" + aten_suffix, + "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix, "//pytorch/tokenizers:hf_tokenizer", "//pytorch/tokenizers:llama2c_tokenizer", "//pytorch/tokenizers:sentencepiece", diff --git a/extension/llm/runner/test/TARGETS b/extension/llm/runner/test/TARGETS index 7544d1607bd..8f758d21ea9 100644 --- a/extension/llm/runner/test/TARGETS +++ b/extension/llm/runner/test/TARGETS @@ -18,6 +18,7 @@ runtime.cxx_test( srcs = ["test_text_decoder_runner.cpp"], deps = [ "//executorch/extension/llm/runner:runner_lib", + "//executorch/extension/llm/runner/io_manager:io_manager", "//executorch/kernels/portable:generated_lib", "//executorch/runtime/core/exec_aten/testing_util:tensor_util", ], diff --git a/extension/llm/runner/test/test_text_decoder_runner.cpp b/extension/llm/runner/test/test_text_decoder_runner.cpp index c9a8de271f1..b23c5361ec3 100644 --- a/extension/llm/runner/test/test_text_decoder_runner.cpp +++ b/extension/llm/runner/test/test_text_decoder_runner.cpp @@ -7,6 +7,7 @@ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */ +#include #include #include #include @@ -18,6 +19,7 @@ using namespace ::testing; using executorch::extension::Module; using executorch::extension::TensorPtr; +using executorch::extension::llm::IOManager; using executorch::extension::llm::TextDecoderRunner; using executorch::runtime::Error; using executorch::runtime::EValue; @@ -34,11 +36,14 @@ class TextDecoderRunnerTest : public Test { protected: void SetUp() override { mock_module_ = std::make_unique(); - runner_ = std::make_unique(mock_module_.get()); + io_manager_ = std::make_unique(); + runner_ = std::make_unique( + mock_module_.get(), io_manager_.get()); } std::unique_ptr mock_module_; std::unique_ptr runner_; + std::unique_ptr io_manager_; }; // Test logits_to_token() method with Float tensor @@ -150,15 +155,17 @@ TEST_F(TextDecoderRunnerTest, StepWithAllModels) { // Load the model auto module = std::make_unique(model_path); + auto load_result = module->load(); if (load_result != Error::Ok) { ADD_FAILURE() << "Failed to load model " << model_name << " from " << model_path << " with error: " << (int)load_result; continue; } - + std::unique_ptr io_manager = + std::make_unique(); // Create TextDecoderRunner - TextDecoderRunner runner(module.get()); + TextDecoderRunner runner(module.get(), io_manager.get()); auto runner_load_result = runner.load(); ASSERT_EQ(runner_load_result, Error::Ok) << "Failed to load runner for " << model_name; diff --git a/extension/llm/runner/test/test_text_llm_runner.cpp b/extension/llm/runner/test/test_text_llm_runner.cpp index 6896c56e961..b5302faebf4 100644 --- a/extension/llm/runner/test/test_text_llm_runner.cpp +++ b/extension/llm/runner/test/test_text_llm_runner.cpp @@ -7,6 +7,7 @@ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated */ +#include #include #include #include @@ -63,7 +64,7 @@ class MockModule : public ::executorch::extension::Module { class MockTextDecoderRunner : public TextDecoderRunner { public: - MockTextDecoderRunner() : TextDecoderRunner(nullptr) {} + MockTextDecoderRunner() : TextDecoderRunner(nullptr, nullptr) {} MOCK_METHOD( Result, step, @@ -219,6 +220,7 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) { std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), + std::make_unique(), std::move(text_token_generator), std::move(stats)); @@ -278,6 +280,7 @@ TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) { std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), + std::make_unique(), std::move(text_token_generator), std::move(stats)); @@ -312,6 +315,7 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) { std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), + std::make_unique(), std::move(text_token_generator), std::move(stats)); @@ -356,6 +360,7 @@ TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) { std::move(text_decoder_runner), std::unique_ptr<::executorch::extension::llm::TextPrefiller>( text_prefiller.release()), + std::make_unique(), std::move(text_token_generator), std::move(stats)); diff --git a/extension/llm/runner/test/test_text_prefiller.cpp b/extension/llm/runner/test/test_text_prefiller.cpp index dc8bdc625e9..2e02fc2a406 100644 --- a/extension/llm/runner/test/test_text_prefiller.cpp +++ b/extension/llm/runner/test/test_text_prefiller.cpp @@ -24,7 +24,7 @@ using executorch::runtime::testing::TensorFactory; // Mock class for TextDecoderRunner class MockTextDecoderRunner : public TextDecoderRunner { public: - MockTextDecoderRunner() : TextDecoderRunner(nullptr) {} + MockTextDecoderRunner() : TextDecoderRunner(nullptr, nullptr) {} MOCK_METHOD( Result, step, diff --git a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp index c50e815ab2a..bffd140eade 100644 --- a/extension/llm/runner/text_decoder_runner.cpp +++ b/extension/llm/runner/text_decoder_runner.cpp @@ -22,7 +22,8 @@ namespace llm { // NOTE: we observed ~2x loading performance increase on iPhone 15 // and a ~5% improvement on Galaxy S22 by switching to // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors. -TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {} +TextDecoderRunner::TextDecoderRunner(Module* module, IOManager* io_manager) + : module_(module), io_manager_(io_manager) {} // This function is functional, meaning it shouldn't modify any state of the // input. It should be safe to call multiple times with the same inputs. The @@ -66,8 +67,22 @@ ::executorch::runtime::Result TextDecoderRunner::step( start_pos_tensor = from_blob( &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long); } - auto outputs_res = module_->forward({tokens, start_pos_tensor}); + + std::vector inputs; + auto method_err = module_->method("forward"); + ET_CHECK_OK_OR_RETURN_ERROR(method_err.error()); + auto& method = *(method_err.get()); + + auto inputs_res = + io_manager_->prepare_decode(tokens, start_pos_tensor, method); + ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error()); + inputs = inputs_res.get(); + auto outputs_res = module_->forward(inputs); ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + + auto update_err = io_manager_->update_decode(method, outputs_res.get()); + ET_CHECK_OK_OR_RETURN_ERROR(update_err); + ET_CHECK_MSG( outputs_res.get().size() == 1, "More then one output returned from executing LLM."); diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index e930763668e..f583ed647a6 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -10,6 +10,7 @@ #pragma once +#include #include #include #include @@ -21,7 +22,7 @@ namespace llm { class ET_EXPERIMENTAL TextDecoderRunner { public: - explicit TextDecoderRunner(Module* module); + explicit TextDecoderRunner(Module* module, IOManager* io_manager); virtual ~TextDecoderRunner() = default; @@ -94,13 +95,14 @@ class ET_EXPERIMENTAL TextDecoderRunner { protected: /** - * Note: TextDecoderRunner does not own the Module instance. It is expected - * that the outer class (likely Runner) manages the lifecycle of the Module. - * This means that the responsibility for creating, maintaining, and + * Note: TextDecoderRunner does not own the Module or IOManager instance. It + * is expected that the outer class (likely Runner) manages the lifecycle of + * them. This means that the responsibility for creating, maintaining, and * destroying the Module lies outside of TextDecoderRunner. Ensure that the * Module remains valid for the duration of TextDecoderRunner's usage. */ Module* module_; + IOManager* io_manager_; bool should_stop_{false}; }; diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index cf55d98224a..4f89121111d 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -10,6 +10,7 @@ // A simple llama2 runner that includes preprocessing and post processing logic. // The module takes in a string as input and emits a string as output. +#include #include #include #include @@ -39,6 +40,7 @@ TextLLMRunner::TextLLMRunner( std::unique_ptr<::executorch::extension::Module> module, std::unique_ptr text_decoder_runner, std::unique_ptr text_prefiller, + std::unique_ptr io_manager, std::unique_ptr text_token_generator, std::unique_ptr stats, float temperature) @@ -47,6 +49,7 @@ TextLLMRunner::TextLLMRunner( module_(std::move(module)), text_decoder_runner_(std::move(text_decoder_runner)), text_prefiller_(std::move(text_prefiller)), + io_manager_(std::move(io_manager)), text_token_generator_(std::move(text_token_generator)), stats_(std::move(stats)), temperature_(temperature) { @@ -63,6 +66,14 @@ Error TextLLMRunner::load() { return Error::Ok; } ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); + auto method_res = module_->method("forward"); + + Program& program = *module_->program(); + + ET_CHECK_OK_OR_RETURN_ERROR(method_res.error()); + auto& forward = *(method_res.get()); + ET_CHECK_OK_OR_RETURN_ERROR(io_manager_->load(program, forward, forward)); ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load()); return Error::Ok; } @@ -393,9 +404,13 @@ std::unique_ptr create_text_llm_runner( auto eos_ids = std::make_unique>( llm::get_eos_ids(tokenizer.get(), module.get())); + // Create IOManager + std::unique_ptr io_manager = std::make_unique(); + // Create text_decoder_runner. Use a shared_ptr so that it can be shared with // TextPrefiller and TextTokenGenerator - auto text_decoder_runner = std::make_unique(module.get()); + auto text_decoder_runner = + std::make_unique(module.get(), io_manager.get()); // Create text_prefiller auto text_prefiller = std::make_unique( @@ -420,6 +435,7 @@ std::unique_ptr create_text_llm_runner( std::move(module), std::move(text_decoder_runner), std::move(text_prefiller), + std::move(io_manager), std::move(text_token_generator), std::move(stats), temperature); diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h index 600d21a8801..c35f143d2e0 100644 --- a/extension/llm/runner/text_llm_runner.h +++ b/extension/llm/runner/text_llm_runner.h @@ -55,6 +55,7 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { std::unique_ptr<::executorch::extension::Module> module, std::unique_ptr text_decoder_runner, std::unique_ptr text_prefiller, + std::unique_ptr io_manager, std::unique_ptr text_token_generator, std::unique_ptr stats, float temperature = -1.0f); @@ -155,6 +156,7 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner { // sure it outlives text_prefiller_ & // text_token_generator_. std::unique_ptr text_prefiller_; + std::unique_ptr io_manager_; std::unique_ptr text_token_generator_; // Stats diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 9af975657c0..43b3cd0f9b8 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -230,6 +230,7 @@ runtime::Error Module::load_method( ET_NODISCARD runtime::Result Module::method( const std::string& method_name) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); ET_CHECK_OR_RETURN_ERROR( methods_.count(method_name) > 0, InvalidArgument,