Commit 75e6413

Android MultiModal JNI binding

Differential Revision: D61568605
Pull Request resolved: #4813

1 parent: 56001c3

9 files changed: 150 additions, 20 deletions

build/build_android_llm_demo.sh (6 additions, 0 deletions)

@@ -30,7 +30,10 @@ build_android_native_library() {
     -DCMAKE_TOOLCHAIN_FILE="${ANDROID_NDK}/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI="${ANDROID_ABI}" \
     -DANDROID_PLATFORM=android-23 \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
+    -DEXECUTORCH_LOG_LEVEL=Info \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_XNNPACK_SHARED_WORKSPACE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
@@ -60,11 +63,14 @@ build_android_native_library() {
 
   cmake --build "${CMAKE_OUT}"/examples/models/llama2 -j "${CMAKE_JOBS}" --config Release
 
+
   cmake extension/android \
     -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \
     -DANDROID_ABI="${ANDROID_ABI}" \
     -DANDROID_PLATFORM=android-23 \
     -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \
+    -DEXECUTORCH_ENABLE_LOGGING=ON \
+    -DEXECUTORCH_LOG_LEVEL=Info \
     -DEXECUTORCH_BUILD_LLAMA_JNI=ON \
     -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \
     -DCMAKE_BUILD_TYPE=Release \

examples/models/llava/runner/llava_image_prefiller.h (2 additions, 3 deletions)

@@ -24,9 +24,8 @@ class LlavaImagePrefiller : public ImagePrefiller {
    * @param start_pos The starting position in KV cache of the input in the LLM
    * @return logits of the image prefill.
    */
-  inline Result<exec_aten::Tensor> prefill(
-      Image& image,
-      int64_t start_pos = 0) {
+  inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
+      override {
     ManagedTensor managed_images(
         image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
     // Run image encoder
examples/models/llava/runner/targets.bzl (new file, 22 additions)

@@ -0,0 +1,22 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    runtime.cxx_library(
+        name = "runner",
+        srcs = ["llava_runner.cpp"],
+        exported_headers = ["llava_runner.h", "llava_image_prefiller.h", "llava_text_decoder_runner.h"],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/backends/xnnpack:xnnpack_backend",
+            "//executorch/extension/llm/runner:runner_lib",
+            "//executorch/extension/llm/tokenizer:bpe_tokenizer",
+            "//executorch/extension/evalue_util:print_evalue",
+            "//executorch/extension/runner_util:managed_tensor",
+            "//executorch/extension/module:module",
+            "//executorch/kernels/quantized:generated_lib",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+        ],
+    )

extension/android/CMakeLists.txt (6 additions, 0 deletions)

@@ -79,6 +79,11 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     TARGET llama_runner PROPERTY IMPORTED_LOCATION ${LLAMA_RUNNER_PATH}
   )
 
+  add_subdirectory(
+    ${EXECUTORCH_ROOT}/examples/models/llava/runner
+    ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner
+  )
+
   set(CUSTOM_OPS_PATH
       ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops/libcustom_ops.a
   )
@@ -116,6 +121,7 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     executorch_llama_jni
     ${link_libraries}
     llama_runner
+    llava_runner
     custom_ops
     cpublas
     eigen_blas

extension/android/jni/jni_layer_llama.cpp (64 additions, 12 deletions)

@@ -16,6 +16,8 @@
 #include <vector>
 
 #include <executorch/examples/models/llama2/runner/runner.h>
+#include <executorch/examples/models/llava/runner/llava_runner.h>
+#include <executorch/extension/llm/runner/image.h>
 #include <executorch/runtime/platform/log.h>
 #include <executorch/runtime/platform/platform.h>
 #include <executorch/runtime/platform/runtime.h>
@@ -90,21 +92,29 @@ class ExecuTorchLlamaJni
     : public facebook::jni::HybridClass<ExecuTorchLlamaJni> {
  private:
   friend HybridBase;
+  int model_type_category_;
   std::unique_ptr<Runner> runner_;
+  std::unique_ptr<MultimodalRunner> multi_modal_runner_;
 
  public:
   constexpr static auto kJavaDescriptor =
       "Lorg/pytorch/executorch/LlamaModule;";
 
+  constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
+  constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
+
   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
+      jint model_type_category,
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
       jfloat temperature) {
-    return makeCxxInstance(model_path, tokenizer_path, temperature);
+    return makeCxxInstance(
+        model_type_category, model_path, tokenizer_path, temperature);
   }
 
   ExecuTorchLlamaJni(
+      jint model_type_category,
       facebook::jni::alias_ref<jstring> model_path,
       facebook::jni::alias_ref<jstring> tokenizer_path,
       jfloat temperature) {
@@ -119,30 +129,72 @@ class ExecuTorchLlamaJni
     }
 #endif
 
-    runner_ = std::make_unique<Runner>(
-        model_path->toStdString().c_str(),
-        tokenizer_path->toStdString().c_str(),
-        temperature);
+    model_type_category_ = model_type_category;
+    if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      multi_modal_runner_ = std::make_unique<LlavaRunner>(
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          temperature);
+    } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
+      runner_ = std::make_unique<Runner>(
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          temperature);
+    }
   }
 
   jint generate(
+      facebook::jni::alias_ref<jintArray> image,
+      jint width,
+      jint height,
+      jint channels,
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       facebook::jni::alias_ref<ExecuTorchLlamaCallbackJni> callback) {
-    runner_->generate(
-        prompt->toStdString(),
-        seq_len,
-        [callback](std::string result) { callback->onResult(result); },
-        [callback](const Stats& result) { callback->onStats(result); });
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      auto image_size = image->size();
+      std::vector<Image> images;
+      if (image_size != 0) {
+        std::vector<jint> image_data_jint(image_size);
+        std::vector<uint8_t> image_data(image_size);
+        image->getRegion(0, image_size, image_data_jint.data());
+        for (int i = 0; i < image_size; i++) {
+          image_data[i] = image_data_jint[i];
+        }
+        Image image_runner{image_data, width, height, channels};
+        images.push_back(image_runner);
+      }
+      multi_modal_runner_->generate(
+          images,
+          prompt->toStdString(),
+          seq_len,
+          [callback](std::string result) { callback->onResult(result); },
+          [callback](const Stats& result) { callback->onStats(result); });
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      runner_->generate(
+          prompt->toStdString(),
+          seq_len,
+          [callback](std::string result) { callback->onResult(result); },
+          [callback](const Stats& result) { callback->onStats(result); });
+    }
     return 0;
   }
 
   void stop() {
-    runner_->stop();
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      multi_modal_runner_->stop();
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      runner_->stop();
+    }
  }
 
   jint load() {
-    return static_cast<jint>(runner_->load());
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return static_cast<jint>(multi_modal_runner_->load());
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      return static_cast<jint>(runner_->load());
+    }
+    return static_cast<jint>(Error::InvalidArgument);
   }
 
   static void registerNatives() {
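Note that the JNI layer widens each incoming Java int into a single uint8_t, so callers are expected to pass one array element per image byte, with values in the 0-255 range, rather than packed ARGB pixels. A minimal caller-side sketch follows; the ImageUtil helper is hypothetical (not part of this commit), and channel-first (CHW) ordering is an assumption based on the {3, image.height, image.width} tensor shape in llava_image_prefiller.h.

public final class ImageUtil {
  // Hypothetical helper: flatten 8-bit CHW pixel data into the int[] that the
  // new native generate() expects. One array element per byte, mirroring the
  // jint -> uint8_t copy in jni_layer_llama.cpp.
  public static int[] toImageIntArray(byte[] chwBytes) {
    int[] out = new int[chwBytes.length];
    for (int i = 0; i < chwBytes.length; i++) {
      out[i] = chwBytes[i] & 0xFF; // mask: Java bytes are signed, values must stay 0-255
    }
    return out;
  }
}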

extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java (34 additions, 3 deletions)

@@ -14,6 +14,10 @@
 import com.facebook.soloader.nativeloader.SystemDelegate;
 
 public class LlamaModule {
+
+  public static final int MODEL_TYPE_TEXT = 1;
+  public static final int MODEL_TYPE_TEXT_VISION = 2;
+
   static {
     if (!NativeLoader.isInitialized()) {
       NativeLoader.init(new SystemDelegate());
@@ -26,11 +30,16 @@ public class LlamaModule {
 
   @DoNotStrip
   private static native HybridData initHybrid(
-      String modulePath, String tokenizerPath, float temperature);
+      int modelType, String modulePath, String tokenizerPath, float temperature);
 
   /** Constructs a LLAMA Module for a model with given path, tokenizer, and temperature. */
   public LlamaModule(String modulePath, String tokenizerPath, float temperature) {
-    mHybridData = initHybrid(modulePath, tokenizerPath, temperature);
+    mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature);
+  }
+
+  /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */
+  public LlamaModule(int modelType, String modulePath, String tokenizerPath, float temperature) {
+    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature);
   }
 
   public void resetNative() {
@@ -54,8 +63,30 @@ public int generate(String prompt, LlamaCallback llamaCallback) {
    * @param seqLen sequence length
    * @param llamaCallback callback object to receive results.
    */
+  public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) {
+    return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback);
+  }
+
+  /**
+   * Start generating tokens from the module.
+   *
+   * @param image Input image as a byte array
+   * @param width Input image width
+   * @param height Input image height
+   * @param channels Input image number of channels
+   * @param prompt Input prompt
+   * @param seqLen sequence length
+   * @param llamaCallback callback object to receive results.
+   */
   @DoNotStrip
-  public native int generate(String prompt, int seqLen, LlamaCallback llamaCallback);
+  public native int generate(
+      int[] image,
+      int width,
+      int height,
+      int channels,
+      String prompt,
+      int seqLen,
+      LlamaCallback llamaCallback);
 
   /** Stop current generate() before it finishes. */
   @DoNotStrip
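With these changes the Java API selects the runner by model type constant. A minimal usage sketch follows; the model and tokenizer paths and the 336x336 RGB input size are placeholders (336x336 is a common LLaVA configuration), and LlamaCallback is assumed to expose onResult(String) and onStats(float) as in the existing bindings.

import org.pytorch.executorch.LlamaCallback;
import org.pytorch.executorch.LlamaModule;

public class LlavaDemo {
  public static void main(String[] args) {
    // Construct a multimodal module using the new model-type constructor.
    LlamaModule module = new LlamaModule(
        LlamaModule.MODEL_TYPE_TEXT_VISION,
        "/data/local/tmp/llava.pte",
        "/data/local/tmp/tokenizer.bin",
        0.8f);
    module.load(); // returns an ExecuTorch error code; 0 means Ok

    int width = 336, height = 336, channels = 3;
    int[] image = new int[channels * height * width]; // one int per image byte
    // ... fill image with pixel values in the 0-255 range ...

    module.generate(image, width, height, channels, "What is on the image?", 128,
        new LlamaCallback() {
          @Override
          public void onResult(String token) {
            System.out.print(token); // stream tokens as they are generated
          }

          @Override
          public void onStats(float tokensPerSecond) {
            // generation statistics reported at the end of a run
          }
        });
  }
}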

extension/llm/runner/image_prefiller.h (2 additions, 0 deletions)

@@ -32,6 +32,8 @@ class ImagePrefiller {
   virtual Error load() = 0;
   virtual bool is_method_loaded() = 0;
 
+  virtual ~ImagePrefiller() = default;
+
  protected:
   Module* module_;
 };

extension/llm/runner/multimodal_runner.h (2 additions, 0 deletions)

@@ -65,6 +65,8 @@ class MultimodalRunner {
     text_token_generator_->stop();
   }
 
+  virtual ~MultimodalRunner() = default;
+
  protected:
   // metadata
   int32_t vocab_size_;

extension/llm/runner/targets.bzl (12 additions, 2 deletions)

@@ -59,6 +59,17 @@ def define_common_targets():
         ],
     )
 
+    runtime.cxx_library(
+        name = "image_prefiller" + aten_suffix,
+        exported_headers = ["image_prefiller.h", "image.h"],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/extension/module:module" + aten_suffix,
+        ],
+    )
+
     runtime.cxx_library(
         name = "metadata_util" + aten_suffix,
         exported_headers = ["metadata_util.h"],
@@ -73,14 +84,13 @@ def define_common_targets():
     runtime.cxx_library(
         name = "runner_lib" + aten_suffix,
         exported_headers = [
-            "image_prefiller.h",
-            "image.h",
             "multimodal_runner.h",
         ],
         visibility = [
             "@EXECUTORCH_CLIENTS",
         ],
         exported_deps = [
+            ":image_prefiller" + aten_suffix,
             ":text_decoder_runner" + aten_suffix,
             ":text_prefiller" + aten_suffix,
             ":text_token_generator" + aten_suffix,
