# Patch summary: introduce a common interface for LLM runners.
#
#   * extension/llm/runner/irunner.h (new file): declares the abstract class
#     executorch::extension::llm::IRunner with the pure-virtual methods
#     is_loaded(), load(), generate() and stop(), plus a virtual destructor.
#   * extension/llm/runner/targets.bzl: adds an "irunner" cxx_library target
#     that exports irunner.h.
#   * examples/models/llama/runner/runner.h: example::Runner now derives
#     publicly from IRunner and includes the new header.
#   * examples/models/llama/runner/targets.bzl: adds the irunner target to the
#     runner's exported_deps.
#
# NOTE(review): this text arrived collapsed onto two lines with everything
# inside angle brackets stripped. Line breaks and indentation were rebuilt
# from the hunk line counts. Header names in the #include lines and the
# std::function template arguments were restored from how the code uses them
# (std::string, std::function, Stats, Error). The unchanged context #include
# lines in runner.h and the Starlark indentation are reconstructed and should
# be confirmed against the upstream commit before applying.
# NOTE(review): <cstdint> was added to irunner.h because generate() takes an
# int32_t; the new-file hunk is therefore 51 lines instead of 50, and the
# "index" blob id for irunner.h still refers to the original 50-line content.
# Blank context lines consist of a single space.
diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h
index 4524aa81aab..5b3bb010112 100644
--- a/examples/models/llama/runner/runner.h
+++ b/examples/models/llama/runner/runner.h
@@ -17,6 +17,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -26,7 +27,7 @@
 
 namespace example {
 
-class ET_EXPERIMENTAL Runner {
+class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  public:
   explicit Runner(
       const std::string& model_path,
diff --git a/examples/models/llama/runner/targets.bzl b/examples/models/llama/runner/targets.bzl
index de12dc4d106..8322f190b32 100644
--- a/examples/models/llama/runner/targets.bzl
+++ b/examples/models/llama/runner/targets.bzl
@@ -30,6 +30,7 @@ def define_common_targets():
             # qnn_executorch_backend can be added below //executorch/backends/qualcomm:qnn_executorch_backend
             exported_deps = [
                 "//executorch/backends/xnnpack:xnnpack_backend",
+                "//executorch/extension/llm/runner:irunner",
                 "//executorch/extension/llm/runner:stats",
                 "//executorch/extension/llm/runner:text_decoder_runner" + aten_suffix,
                 "//executorch/extension/llm/runner:text_prefiller" + aten_suffix,
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
new file mode 100644
index 00000000000..35d87e997a0
--- /dev/null
+++ b/extension/llm/runner/irunner.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// An interface for LLM runners. Developers can create their own runner that
+// implements their own load and generation logic to run the model.
+
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <string>
+
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/runtime/core/error.h>
+
+namespace executorch {
+namespace extension {
+namespace llm {
+
+class ET_EXPERIMENTAL IRunner {
+ public:
+  virtual ~IRunner() = default;
+
+  // Checks if the model is loaded.
+  virtual bool is_loaded() const = 0;
+
+  // Load the model and tokenizer.
+  virtual ::executorch::runtime::Error load() = 0;
+
+  // Generate the output tokens.
+  virtual ::executorch::runtime::Error generate(
+      const std::string& prompt,
+      int32_t seq_len,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const ::executorch::extension::llm::Stats&)>
+          stats_callback = {},
+      bool echo = true,
+      bool warming = false) = 0;
+
+  // Stop the generation.
+  virtual void stop() = 0;
+};
+
+} // namespace llm
+} // namespace extension
+} // namespace executorch
diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl
index f20240956cb..aa42c22b1b9 100644
--- a/extension/llm/runner/targets.bzl
+++ b/extension/llm/runner/targets.bzl
@@ -1,6 +1,16 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 def define_common_targets():
+    runtime.cxx_library(
+        name = "irunner",
+        exported_headers = [
+            "irunner.h",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )
+
     runtime.cxx_library(
         name = "stats",
         exported_headers = [