2 changes: 2 additions & 0 deletions examples/mediatek/executor_runner/mtk_llama_runner.h
@@ -66,6 +66,8 @@ class MTKLlamaRunner : public executorch::extension::llm::IRunner {
std::function<void(const std::string&)> token_callback);
std::unique_ptr<Tokenizer> load_tokenizer();

void reset() {}

private:
// model
const LlamaModelOptions modeloptions_;
14 changes: 2 additions & 12 deletions examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -354,17 +354,6 @@ Error Runner<T>::generate(
const llm::GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
}

template <typename T>
Error Runner<T>::generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const llm::GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
// TODO: currently only support start_pos == 0
return generate_from_prompt_or_file(
prompt, false, config, token_callback, stats_callback);
}
@@ -435,7 +424,8 @@ Error Runner<T>::generate_from_prompt_or_file(
stats_.first_token_ms = time_in_ms();
stats_.prompt_eval_end_ms = time_in_ms();

// print the first token from prefill. No prev_token so use cur_token for it.
// print the first token from prefill. No prev_token so use cur_token for
// it.
if (token_callback) {
token_callback(
ET_UNWRAP_TOKENIZER(tokenizer_->decode(cur_token, cur_token)));
9 changes: 2 additions & 7 deletions examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -72,20 +72,15 @@ class Runner : public executorch::extension::llm::IRunner {
std::function<void(const std::string&)> token_callback = {},
std::function<void(const executorch::llm::Stats&)> stats_callback = {})
override;
executorch::runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const executorch::extension::llm::GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const executorch::llm::Stats&)> stats_callback = {})
override;

executorch::runtime::Error generate_from_prompt_or_file(
const std::string& prompt,
bool tokenized_prompt,
const executorch::extension::llm::GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const executorch::llm::Stats&)> stats_callback = {});
void stop() override {};
void reset() override {};
executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();

private:
25 changes: 2 additions & 23 deletions extension/llm/runner/irunner.h
@@ -125,39 +125,18 @@ class ET_EXPERIMENTAL IRunner {
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;

/**
* Generate text based on the provided prompt and generation config, from a
* given position in KV cache.
*
* @param prompt The input prompt to generate from
* @param start_pos The starting position in KV cache of the input. Note:
* Depending on the actual implementation, a runner may manage the position
* internally, and this may not be respected.
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
* @return Error::Ok if successful, an error otherwise
*/
virtual runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;
/**
* Stop the generation process.
*/
virtual void stop() = 0;

/**
* Force remove prefilled tokens and reset KV cache start position
*
* For some existing runners, overriding this method is not needed because
* start_pos is passed as an argument to generate_from_pos.
*
* This method removes the prefilled tokens from the KV cache and resets the
* start position to 0.
*/
virtual void reset() {};
virtual void reset() = 0;
};

} // namespace llm
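With `generate_from_pos` removed and `reset()` now pure virtual, a caller that previously resumed from an explicit `start_pos` instead lets the runner track the position internally and clears the KV cache between independent generations. A minimal sketch of that calling pattern against the `IRunner` methods shown above; the `run_twice` helper, the prompts, and the `executorch/...` include prefix are illustrative assumptions, not part of this diff:

```cpp
#include <cstdio>
#include <functional>
#include <string>

#include <executorch/extension/llm/runner/irunner.h>

namespace llm = executorch::extension::llm;

// Hypothetical helper: run two unrelated prompts on one already-loaded runner.
// Before this change the second call might have gone through
// generate_from_pos(); now the KV cache is cleared explicitly with reset().
executorch::runtime::Error run_twice(
    llm::IRunner& runner,
    const llm::GenerationConfig& config) {
  auto on_token = [](const std::string& piece) {
    // Stream each decoded piece as it arrives.
    ::printf("%s", piece.c_str());
  };

  // First generation; assumes the runner starts with an empty KV cache.
  auto err = runner.generate("First prompt", config, on_token, {});
  if (err != executorch::runtime::Error::Ok) {
    return err;
  }

  // Remove the prefilled tokens and reset the KV cache start position to 0
  // before generating from a prompt that shares no prefix with the first.
  runner.reset();
  return runner.generate("Second prompt", config, on_token, {});
}
```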
13 changes: 2 additions & 11 deletions extension/llm/runner/text_llm_runner.cpp
@@ -43,6 +43,7 @@ TextLLMRunner::TextLLMRunner(
io_manager_(std::move(io_manager)),
text_token_generator_(std::move(text_token_generator)),
stats_(std::move(stats)),
pos_(0),
temperature_(temperature) {
// Note: This constructor assumes that text_prefiller and text_token_generator
// already have references to the Module and TextDecoderRunner they need
@@ -70,9 +71,8 @@ Error TextLLMRunner::load() {
ET_LOG(Info, format, __VA_ARGS__); \
}

Error TextLLMRunner::generate_from_pos(
Error TextLLMRunner::generate(
const std::string& prompt,
ET_UNUSED int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
@@ -217,15 +217,6 @@ Error TextLLMRunner::generate_from_pos(
return Error::Ok;
}

Error TextLLMRunner::generate(
const std::string& prompt,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
pos_ = 0;
return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
}

Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
// Create a GenerationConfig for warmup
GenerationConfig config{
21 changes: 2 additions & 19 deletions extension/llm/runner/text_llm_runner.h
@@ -101,25 +101,6 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* Generate text based on the provided prompt and generation config, from a
* given position in KV cache.
*
* @param prompt The input prompt to generate from
* @param start_pos [Unused] The starting position in KV cache of the input,
* ignored because the runner manages the position internally.
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
* @return Error::Ok if successful, an error otherwise
*/
ET_DEPRECATED runtime::Error generate_from_pos(
const std::string& prompt,
ET_UNUSED int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;

/**
* @brief Warms up the model with a sample prompt
*
@@ -133,13 +114,15 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
::executorch::runtime::Error warmup(
const std::string& prompt,
int32_t max_new_tokens);

/**
* @brief Remove prefilled tokens and reset start position, and stats.
*
* This method removes the prefilled tokens from the KV cache and resets the
* start position to 0. It also clears the stats for previous runs.
*/
void reset() override;

/**
* @brief Stops the ongoing text generation process
*
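On the concrete `TextLLMRunner`, `generate()` no longer forces the position back to 0 (the old wrapper that set `pos_ = 0` is gone and `pos_` is only zero-initialized in the constructor), so successive calls continue from the current KV-cache position and an explicit `reset()` is needed between unrelated prompts. A small usage sketch under those assumptions; the `serve_two_prompts` helper, the prompts, the token budget, and the assumption that the runner is already constructed and loaded are illustrative, not part of this diff:

```cpp
#include <string>

#include <executorch/extension/llm/runner/text_llm_runner.h>

namespace llm = executorch::extension::llm;

// Hypothetical helper: warm the model up once, then serve two independent
// prompts, clearing the KV cache and per-run stats in between.
executorch::runtime::Error serve_two_prompts(
    llm::TextLLMRunner& runner,
    const llm::GenerationConfig& config) {
  // Short warmup run to exercise prefill and decode once.
  auto err = runner.warmup("warmup prompt", /*max_new_tokens=*/8);
  if (err != executorch::runtime::Error::Ok) {
    return err;
  }

  // token_callback and stats_callback default to {} in the header above.
  err = runner.generate("Tell me a story about a lighthouse.", config);
  if (err != executorch::runtime::Error::Ok) {
    return err;
  }

  // Remove the prefilled tokens, reset the start position to 0, and clear the
  // previous run's stats so the next prompt starts from a clean cache.
  runner.reset();
  return runner.generate("Write a haiku about the ocean.", config);
}
```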