Skip to content

Commit 487d584

Browse files
Android LlmModule add API for normalized image input (#15145)
### Summary Some models require normalized image, as float[] ### Test plan CI Co-authored-by: Hansong Zhang <[email protected]>
1 parent 39c09b1 commit 487d584

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,29 @@ public long prefillImages(int[] image, int width, int height, int channels) {
188188

189189
private native int appendImagesInput(int[] image, int width, int height, int channels);
190190

191+
/**
192+
* Prefill an LLaVA Module with the given images input.
193+
*
194+
* @param image Input normalized image as a float array
195+
* @param width Input image width
196+
* @param height Input image height
197+
* @param channels Input image number of channels
198+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
199+
* exposed to user.
200+
* @throws RuntimeException if the prefill failed
201+
*/
202+
@Deprecated
203+
public long prefillImages(float[] image, int width, int height, int channels) {
204+
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
205+
if (nativeResult != 0) {
206+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
207+
}
208+
return 0;
209+
}
210+
211+
private native int appendNormalizedImagesInput(
212+
float[] image, int width, int height, int channels);
213+
191214
/**
192215
* Prefill an LLaVA Module with the given text input.
193216
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,32 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
276276
return 0;
277277
}
278278

279+
// Returns status_code
280+
jint append_normalized_images_input(
281+
facebook::jni::alias_ref<jfloatArray> image,
282+
jint width,
283+
jint height,
284+
jint channels) {
285+
std::vector<llm::Image> images;
286+
if (image == nullptr) {
287+
return static_cast<jint>(Error::EndOfMethod);
288+
}
289+
auto image_size = image->size();
290+
if (image_size != 0) {
291+
std::vector<jfloat> image_data_jfloat(image_size);
292+
std::vector<float> image_data(image_size);
293+
image->getRegion(0, image_size, image_data_jfloat.data());
294+
for (int i = 0; i < image_size; i++) {
295+
image_data[i] = image_data_jfloat[i];
296+
}
297+
llm::Image image_runner{std::move(image_data), width, height, channels};
298+
prefill_inputs_.emplace_back(
299+
llm::MultimodalInput{std::move(image_runner)});
300+
}
301+
302+
return 0;
303+
}
304+
279305
void stop() {
280306
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
281307
multi_modal_runner_->stop();
@@ -310,6 +336,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
310336
makeNativeMethod("load", ExecuTorchLlmJni::load),
311337
makeNativeMethod(
312338
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
339+
makeNativeMethod(
340+
"appendNormalizedImagesInput",
341+
ExecuTorchLlmJni::append_normalized_images_input),
313342
makeNativeMethod(
314343
"appendTextInput", ExecuTorchLlmJni::append_text_input),
315344
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),

0 commit comments

Comments
 (0)