Skip to content

Commit 96ed832

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Add a default image prefiller implementation (#13310)
Summary: As titled. I need to create an interface `IModule` for `Module` class to override, to make it test-able. Differential Revision: D80063769
1 parent 97a3aac commit 96ed832

File tree

9 files changed

+706
-11
lines changed

9 files changed

+706
-11
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Given an image tensor, prefill the KV cache of LLaVA.
10+
11+
#include <executorch/extension/llm/runner/constants.h>
12+
#include <executorch/extension/llm/runner/image_prefiller.h>
13+
#include <executorch/extension/tensor/tensor.h>
14+
15+
namespace executorch::extension::llm {
16+
/**
17+
* Prefill an LLM Module with the given image input.
18+
* @param image The image input to LLaVa.
19+
* @param start_pos The starting position in KV cache of the input in the LLM
20+
* @return logits of the image prefill.
21+
*/
22+
::executorch::runtime::Result<uint64_t> ImagePrefiller::prefill(
23+
::executorch::extension::llm::Image& image,
24+
int64_t& start_pos) {
25+
auto image_tensor = executorch::extension::from_blob(
26+
image.data.data(),
27+
{3, image.height, image.width},
28+
::executorch::aten::ScalarType::Byte);
29+
// Run image encoder
30+
auto image_encoder_outputs =
31+
ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));
32+
33+
// inputs:[start_pos, embeds]
34+
auto start_pos_tensor = executorch::extension::from_blob(
35+
&start_pos, {1}, ::executorch::aten::ScalarType::Long);
36+
37+
// Run text model
38+
auto outputs_res = ET_UNWRAP(module_->execute(
39+
kTextModelMethod, {start_pos_tensor, image_encoder_outputs[0]}));
40+
ET_CHECK_MSG(
41+
outputs_res[0].isTensor(),
42+
"Non Tensor Output returned from executing image prefill");
43+
44+
// Update the start_pos, which is only available inside this function.
45+
// outputs_res can have only one logits.
46+
start_pos += image_encoder_outputs[0].toTensor().size(1);
47+
48+
return logits_to_token(outputs_res[0].toTensor());
49+
}
50+
51+
/**
52+
* Load the Module for image prefill purpose.
53+
* @return The error code.
54+
*/
55+
::executorch::runtime::Error ImagePrefiller::load() {
56+
if (is_method_loaded()) {
57+
return ::executorch::runtime::Error::Ok;
58+
}
59+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
60+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
61+
return ::executorch::runtime::Error::Ok;
62+
}
63+
64+
/**
65+
* Check if the required methods in the Module is loaded.
66+
* @return True if the Module is loaded, false otherwise.
67+
*/
68+
bool ImagePrefiller::is_method_loaded() {
69+
::executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
70+
module_->method_names();
71+
if (methods_res.error() != ::executorch::runtime::Error::Ok) {
72+
ET_CHECK_MSG(false, "Failed to get method names");
73+
}
74+
std::unordered_set<std::string> methods = methods_res.get();
75+
bool methods_exist = methods.find(kImageEncoderMethod) != methods.end() &&
76+
methods.find(kTextModelMethod) != methods.end();
77+
if (!methods_exist) {
78+
for (const auto& method : methods) {
79+
ET_LOG(Error, "Method: %s", method.c_str());
80+
}
81+
ET_CHECK_MSG(
82+
methods_exist,
83+
"Missing required methods (%s, %s) in the model",
84+
kImageEncoderMethod,
85+
kTextModelMethod);
86+
}
87+
bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
88+
module_->is_method_loaded(kTextModelMethod);
89+
return methods_loaded;
90+
}
91+
92+
} // namespace executorch::extension::llm

extension/llm/runner/image_prefiller.h

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#pragma once
1212

1313
#include <executorch/extension/llm/runner/image.h>
14+
#include <executorch/extension/llm/sampler/sampler.h>
1415
#include <executorch/extension/module/module.h>
1516
#include <executorch/runtime/platform/compiler.h>
1617

@@ -21,7 +22,7 @@ namespace llm {
2122
// Assuming kv cache and parallel prefill are enabled.
2223
class ET_EXPERIMENTAL ImagePrefiller {
2324
public:
24-
explicit ImagePrefiller(::executorch::extension::Module* module)
25+
explicit ImagePrefiller(::executorch::extension::ET_MODULE_NAMESPACE::IModule* module)
2526
: module_(module) {}
2627

2728
/**
@@ -31,17 +32,53 @@ class ET_EXPERIMENTAL ImagePrefiller {
3132
* It's passed as reference and will be updated inside this function.
3233
* @return The next token of the LLM Module after prefill.
3334
*/
34-
virtual ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
35+
virtual ::executorch::runtime::Result<uint64_t> prefill(
3536
Image& image,
36-
int64_t& start_pos) = 0;
37+
int64_t& start_pos);
3738

38-
virtual ::executorch::runtime::Error load() = 0;
39-
virtual bool is_method_loaded() = 0;
39+
virtual ::executorch::runtime::Error load();
40+
virtual bool is_method_loaded();
4041

4142
virtual ~ImagePrefiller() = default;
4243

4344
protected:
44-
Module* module_;
45+
/**
46+
* Sample the next token from the logits tensor.
47+
* @param logits_tensor The logits tensor.
48+
* @param temperature The temperature parameter used to control randomness in
49+
* sampling.
50+
* @return The next token.
51+
*/
52+
inline uint64_t logits_to_token(
53+
const executorch::aten::Tensor& logits_tensor,
54+
const float temperature = 0.0f) {
55+
uint64_t result = 0;
56+
ET_SWITCH_THREE_TYPES(
57+
Float,
58+
Half,
59+
BFloat16,
60+
logits_tensor.scalar_type(),
61+
unused,
62+
"logits_to_token",
63+
CTYPE,
64+
[&]() {
65+
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
66+
// vocab_size], get the last logits, sample and return. Else the model
67+
// outputs the last logit, directly sample and return.
68+
auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
69+
ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
70+
if (logits_tensor.dim() == 3) {
71+
auto num_tokens = logits_tensor.size(1);
72+
logits += (num_tokens - 1) * vocab_size;
73+
}
74+
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
75+
Sampler sampler(vocab_size, temperature);
76+
result = sampler.sample(logits);
77+
});
78+
return result;
79+
}
80+
81+
::executorch::extension::ET_MODULE_NAMESPACE::IModule* module_;
4582
};
4683

4784
} // namespace llm

extension/llm/runner/targets.bzl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,15 @@ def define_common_targets():
8484
runtime.cxx_library(
8585
name = "image_prefiller" + aten_suffix,
8686
exported_headers = ["image_prefiller.h", "image.h"],
87+
srcs = ["image_prefiller.cpp"],
8788
visibility = [
8889
"@EXECUTORCH_CLIENTS",
8990
],
9091
exported_deps = [
9192
":constants",
9293
"//executorch/extension/module:module" + aten_suffix,
94+
"//executorch/extension/tensor:tensor" + aten_suffix,
95+
"//executorch/extension/llm/sampler:sampler" + aten_suffix,
9396
],
9497
)
9598

extension/llm/runner/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
1919

2020
set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp
2121
test_text_prefiller.cpp test_text_decoder_runner.cpp
22+
test_image_prefiller.cpp
2223
)
2324

2425
et_cxx_test(

extension/llm/runner/test/targets.bzl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,12 @@ def define_common_targets():
3636
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
3737
],
3838
)
39+
40+
runtime.cxx_test(
41+
name = "test_image_prefiller",
42+
srcs = ["test_image_prefiller.cpp"],
43+
deps = [
44+
"//executorch/extension/llm/runner:runner_lib",
45+
"//executorch/runtime/core/exec_aten/testing_util:tensor_util",
46+
],
47+
)

0 commit comments

Comments
 (0)