Skip to content

Commit 6d559c4

Browse files
committed
Add API for normalized image input
1 parent 57a7903 commit 6d559c4

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,28 @@ public long prefillImages(int[] image, int width, int height, int channels) {
207207

208208
private native int appendImagesInput(int[] image, int width, int height, int channels);
209209

210+
/**
211+
* Prefill an LLaVA Module with the given images input.
212+
*
213+
* @param image Input normalized image as a float array
214+
* @param width Input image width
215+
* @param height Input image height
216+
* @param channels Input image number of channels
217+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
218+
* exposed to user.
219+
* @throws RuntimeException if the prefill failed
220+
*/
221+
@Deprecated
222+
public long prefillImages(float[] image, int width, int height, int channels) {
223+
int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
224+
if (nativeResult != 0) {
225+
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
226+
}
227+
return 0;
228+
}
229+
230+
private native int appendNormalizedImagesInput(float[] image, int width, int height, int channels);
231+
210232
/**
211233
* Prefill an LLaVA Module with the given text input.
212234
*

extension/android/jni/jni_layer_llama.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,32 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
289289
return 0;
290290
}
291291

292+
// Returns status_code
293+
jint append_normalized_images_input(
294+
facebook::jni::alias_ref<jfloatArray> image,
295+
jint width,
296+
jint height,
297+
jint channels) {
298+
std::vector<llm::Image> images;
299+
if (image == nullptr) {
300+
return static_cast<jint>(Error::EndOfMethod);
301+
}
302+
auto image_size = image->size();
303+
if (image_size != 0) {
304+
std::vector<jfloat> image_data_jfloat(image_size);
305+
std::vector<float> image_data(image_size);
306+
image->getRegion(0, image_size, image_data_jfloat.data());
307+
for (int i = 0; i < image_size; i++) {
308+
image_data[i] = image_data_jfloat[i];
309+
}
310+
llm::Image image_runner{std::move(image_data), width, height, channels};
311+
prefill_inputs_.emplace_back(
312+
llm::MultimodalInput{std::move(image_runner)});
313+
}
314+
315+
return 0;
316+
}
317+
292318
void stop() {
293319
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
294320
multi_modal_runner_->stop();
@@ -323,6 +349,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
323349
makeNativeMethod("load", ExecuTorchLlmJni::load),
324350
makeNativeMethod(
325351
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
352+
makeNativeMethod(
353+
"appendNormalizedImagesInput",
354+
ExecuTorchLlmJni::append_normalized_images_input),
326355
makeNativeMethod(
327356
"appendTextInput", ExecuTorchLlmJni::append_text_input),
328357
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),

0 commit comments

Comments
 (0)