Skip to content

Commit a8cbd3a

Browse files
committed
Update Android API
1 parent 026693d commit a8cbd3a

File tree

2 files changed

+19
-39
lines changed

2 files changed

+19
-39
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -173,23 +173,23 @@ public native int generate(
173173
* @param height Input image height
174174
* @param channels Input image number of channels
175175
* @param startPos The starting position in KV cache of the input in the LLM.
176-
* @return The updated starting position in KV cache of the input in the LLM.
176+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
177+
* exposed to user.
177178
* @throws RuntimeException if the prefill failed
178179
*/
179180
@Deprecated
180181
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
181182
if (startPos == 0) {
182183
resetContext();
183184
}
184-
long[] nativeResult = prefillImagesNative(image, width, height, channels);
185-
if (nativeResult[0] != 0) {
185+
int nativeResult = prefillImagesNative(image, width, height, channels);
186+
if (nativeResult != 0) {
186187
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
187188
}
188-
return nativeResult[1];
189+
return 0;
189190
}
190191

191-
// returns a tuple of (status, updated startPos)
192-
private native long[] prefillImagesNative(int[] image, int width, int height, int channels);
192+
private native int prefillImagesNative(int[] image, int width, int height, int channels);
193193

194194
/**
195195
* Prefill an LLaVA Module with the given text input.
@@ -199,23 +199,24 @@ public long prefillImages(int[] image, int width, int height, int channels, long
199199
* reference and will be updated inside this function.
200200
* @param bos The number of BOS (begin of sequence) token.
201201
* @param eos The number of EOS (end of sequence) token.
202-
* @return The updated starting position in KV cache of the input in the LLM.
202+
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
203+
* exposed to user.
203204
* @throws RuntimeException if the prefill failed
204205
*/
205206
@Deprecated
206207
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
207208
if (startPos == 0) {
208209
resetContext();
209210
}
210-
long[] nativeResult = prefillPromptNative(prompt, bos, eos);
211-
if (nativeResult[0] != 0) {
211+
int nativeResult = prefillPromptNative(prompt, bos, eos);
212+
if (nativeResult != 0) {
212213
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
213214
}
214-
return nativeResult[1];
215+
return 0;
215216
}
216217

217218
// returns a tuple of (status, updated startPos)
218-
private native long[] prefillPromptNative(String prompt, int bos, int eos);
219+
private native int prefillPromptNative(String prompt, int bos, int eos);
219220

220221
/**
221222
* Generate tokens from the given prompt, starting from the given position.
@@ -238,7 +239,7 @@ public native int generateFromPos(
238239
*
239240
* <p>The startPos will be reset to 0.
240241
*/
241-
public native int resetContext();
242+
public native void resetContext();
242243

243244
/** Stop current generate() before it finishes. */
244245
@DoNotStrip

extension/android/jni/jni_layer_llama.cpp

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -251,42 +251,24 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
251251
return 0;
252252
}
253253

254-
// Returns a tuple of (error, start_pos)
255-
// Contract is valid within an AAR (JNI + corresponding Java code)
256-
// If the first element is not Error::Ok, the other element is undefined.
257-
facebook::jni::local_ref<jlongArray>
254+
jint
258255
prefill_prompt(facebook::jni::alias_ref<jstring> prompt, jint bos, jint eos) {
259-
facebook::jni::local_ref<jlongArray> tuple_result =
260-
facebook::jni::make_long_array(2);
261256
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
262-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
263-
return tuple_result;
257+
return static_cast<jint>(Error::NotSupported);
264258
}
265259

266260
auto&& result =
267261
multi_modal_runner_->prefill_prompt(prompt->toStdString(), bos, eos);
268-
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
269-
if (result.ok()) {
270-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
271-
}
272-
return tuple_result;
262+
return static_cast<jint>(result.error());
273263
}
274264

275-
// Returns a tuple of (error, start_pos)
276-
// Contract is valid within an AAR (JNI + corresponding Java code)
277-
// If the first element is not Error::Ok, the other element is undefined.
278-
279-
facebook::jni::local_ref<jlongArray> prefill_images(
265+
jint prefill_images(
280266
facebook::jni::alias_ref<jintArray> image,
281267
jint width,
282268
jint height,
283269
jint channels) {
284-
facebook::jni::local_ref<jlongArray> tuple_result =
285-
facebook::jni::make_long_array(2);
286-
287270
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
288-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
289-
return tuple_result;
271+
return static_cast<jint>(Error::NotSupported);
290272
}
291273

292274
auto image_size = image->size();
@@ -301,12 +283,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
301283
llm::Image image_runner{image_data, width, height, channels};
302284
images.push_back(image_runner);
303285
}
304-
// TODO(hsz): make start_pos a reference and update it here
305286
jint result =
306287
static_cast<jint>(multi_modal_runner_->prefill_images(images));
307-
tuple_result->pin()[0] = result;
308-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
309-
return tuple_result;
288+
return result;
310289
}
311290

312291
jint generate_from_pos(

0 commit comments

Comments
 (0)