Skip to content

Commit 53b7ec5

Browse files
authored
Android Refactor generate() (#14183)
### Summary: We will no longer pass the startPos, bos, and eos arguments; users don't need to know about them. For generate_from_pos with image and text prompts, the module now adds the input prompts first and then invokes generate(). ### Test plan: CI
1 parent de30390 commit 53b7ec5

File tree

3 files changed

+27
-53
lines changed

3 files changed

+27
-53
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,6 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
8181
private Runnable memoryUpdater;
8282
private boolean mThinkMode = false;
8383
private int promptID = 0;
84-
private long startPos = 0;
8584
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
8685
private Executor executor;
8786

@@ -178,7 +177,8 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
178177

179178
if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
180179
ETLogging.getInstance().log("Llava start prefill prompt");
181-
startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0);
180+
mModule.resetContext();
181+
mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt());
182182
ETLogging.getInstance().log("Llava completes prefill prompt");
183183
}
184184
}
@@ -645,13 +645,11 @@ private void showMediaPreview(List<Uri> uris) {
645645
ETLogging.getInstance().log("Starting runnable prefill image");
646646
ETImage img = processedImageList.get(0);
647647
ETLogging.getInstance().log("Llava start prefill image");
648-
startPos =
649-
mModule.prefillImages(
650-
img.getInts(),
651-
img.getWidth(),
652-
img.getHeight(),
653-
ModelUtils.VISION_MODEL_IMAGE_CHANNELS,
654-
startPos);
648+
mModule.prefillImages(
649+
img.getInts(),
650+
img.getWidth(),
651+
img.getHeight(),
652+
ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
655653
};
656654
executor.execute(runnable);
657655
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 12 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -125,9 +125,7 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
125125
* @param llmCallback callback object to receive results
126126
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
127127
*/
128-
public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
129-
return generate(null, 0, 0, 0, prompt, seqLen, llmCallback, echo);
130-
}
128+
public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
131129

132130
/**
133131
* Start generating tokens from the module.
@@ -154,16 +152,19 @@ public int generate(String prompt, LlmGenerationConfig config, LlmCallback llmCa
154152
* @param llmCallback callback object to receive results.
155153
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
156154
*/
157-
@DoNotStrip
158-
public native int generate(
155+
public int generate(
159156
int[] image,
160157
int width,
161158
int height,
162159
int channels,
163160
String prompt,
164161
int seqLen,
165162
LlmCallback llmCallback,
166-
boolean echo);
163+
boolean echo) {
164+
prefillPrompt(prompt);
165+
prefillImages(image, width, height, channels);
166+
return generate("", llmCallback, echo);
167+
}
167168

168169
/**
169170
* Prefill an LLaVA Module with the given images input.
@@ -172,16 +173,12 @@ public native int generate(
172173
* @param width Input image width
173174
* @param height Input image height
174175
* @param channels Input image number of channels
175-
* @param startPos The starting position in KV cache of the input in the LLM.
176176
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
177177
* exposed to user.
178178
* @throws RuntimeException if the prefill failed
179179
*/
180180
@Deprecated
181-
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
182-
if (startPos == 0) {
183-
resetContext();
184-
}
181+
public long prefillImages(int[] image, int width, int height, int channels) {
185182
int nativeResult = appendImagesInput(image, width, height, channels);
186183
if (nativeResult != 0) {
187184
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
@@ -195,28 +192,21 @@ public long prefillImages(int[] image, int width, int height, int channels, long
195192
* Prefill an LLaVA Module with the given text input.
196193
*
197194
* @param prompt The text prompt to LLaVA.
198-
* @param startPos The starting position in KV cache of the input in the LLM. It's passed as
199-
* reference and will be updated inside this function.
200-
* @param bos The number of BOS (begin of sequence) token.
201-
* @param eos The number of EOS (end of sequence) token.
202195
* @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
203196
* exposed to user.
204197
* @throws RuntimeException if the prefill failed
205198
*/
206199
@Deprecated
207-
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
208-
if (startPos == 0) {
209-
resetContext();
210-
}
211-
int nativeResult = appendTextInput(prompt, bos, eos);
200+
public long prefillPrompt(String prompt) {
201+
int nativeResult = appendTextInput(prompt);
212202
if (nativeResult != 0) {
213203
throw new RuntimeException("Prefill failed with error code: " + nativeResult);
214204
}
215205
return 0;
216206
}
217207

218-
// returns a tuple of (status, updated startPos)
219-
private native int appendTextInput(String prompt, int bos, int eos);
208+
// returns status
209+
private native int appendTextInput(String prompt);
220210

221211
/**
222212
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.

extension/android/jni/jni_layer_llama.cpp

Lines changed: 8 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -208,29 +208,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
208208
}
209209

210210
jint generate(
211-
facebook::jni::alias_ref<jintArray> image,
212-
jint width,
213-
jint height,
214-
jint channels,
215211
facebook::jni::alias_ref<jstring> prompt,
216212
jint seq_len,
217213
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
218214
jboolean echo) {
219215
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
220216
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
221217
prefill_inputs_.clear();
222-
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
223-
auto image_size = image->size();
224-
std::vector<llm::Image> images;
225-
if (image_size != 0) {
226-
std::vector<jint> image_data_jint(image_size);
227-
std::vector<uint8_t> image_data(image_size);
228-
image->getRegion(0, image_size, image_data_jint.data());
229-
for (int i = 0; i < image_size; i++) {
230-
image_data[i] = image_data_jint[i];
231-
}
232-
llm::Image image_runner{image_data, width, height, channels};
233-
inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
218+
if (!prompt->toStdString().empty()) {
219+
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
234220
}
235221
executorch::extension::llm::GenerationConfig config{
236222
.echo = static_cast<bool>(echo),
@@ -257,23 +243,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
257243
return 0;
258244
}
259245

260-
// Returns a tuple of (error, start_pos)
246+
// Returns status_code
261247
// Contract is valid within an AAR (JNI + corresponding Java code)
262-
// If the first element is not Error::Ok, the other element is undefined.
263-
jint append_text_input(
264-
facebook::jni::alias_ref<jstring> prompt,
265-
jint bos,
266-
jint eos) {
248+
jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
267249
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
268250
return 0;
269251
}
270252

253+
// Returns status_code
271254
jint append_images_input(
272255
facebook::jni::alias_ref<jintArray> image,
273256
jint width,
274257
jint height,
275258
jint channels) {
276259
std::vector<llm::Image> images;
260+
if (image == nullptr) {
261+
return static_cast<jint>(Error::EndOfMethod);
262+
}
277263
auto image_size = image->size();
278264
if (image_size != 0) {
279265
std::vector<jint> image_data_jint(image_size);

0 commit comments

Comments (0)