Skip to content

Commit baa5552

Browse files
authored
Android demo app update (#93)
Add Gemma 3 support. This requires some special handling for the stop token and the image type. Image resizing still needs to be figured out next. Also add a skeleton for an audio input button.
1 parent f870a2a commit baa5552

File tree

4 files changed

+50
-19
lines changed

4 files changed

+50
-19
lines changed

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
6161
private ImageButton mSendButton;
6262
private ImageButton mGalleryButton;
6363
private ImageButton mCameraButton;
64+
private ImageButton mAudioButton;
6465
private ListView mMessagesView;
6566
private MessageAdapter mMessageAdapter;
6667
private LlmModule mModule = null;
@@ -81,30 +82,36 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
8182
private Runnable memoryUpdater;
8283
private boolean mThinkMode = false;
8384
private int promptID = 0;
84-
private long startPos = 0;
85-
private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2;
8685
private Executor executor;
8786
private boolean sawStartHeaderId = false;
8887

8988
@Override
9089
public void onResult(String result) {
9190
if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) {
91+
// For gemma and llava, we need to call stop() explicitly
92+
if (mCurrentSettingsFields.getModelType() == ModelType.GEMMA_3
93+
|| mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
94+
mModule.stop();
95+
}
9296
return;
9397
}
9498
result = PromptFormat.replaceSpecialToken(mCurrentSettingsFields.getModelType(), result);
9599

96-
if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_3 && result.equals("<|start_header_id|>")) {
100+
if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_3
101+
&& result.equals("<|start_header_id|>")) {
97102
sawStartHeaderId = true;
98103
}
99-
if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_3 && result.equals("<|end_header_id|>")) {
104+
if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_3
105+
&& result.equals("<|end_header_id|>")) {
100106
sawStartHeaderId = false;
101107
return;
102108
}
103109
if (sawStartHeaderId) {
104110
return;
105111
}
106112

107-
boolean keepResult = !(result.equals("\n") || result.equals("\n\n")) || !mResultMessage.getText().isEmpty();
113+
boolean keepResult =
114+
!(result.equals("\n") || result.equals("\n\n")) || !mResultMessage.getText().isEmpty();
108115
if (keepResult) {
109116
mResultMessage.appendText(result);
110117
run();
@@ -466,6 +473,11 @@ private void setupMediaButton() {
466473
.setMediaType(ActivityResultContracts.PickVisualMedia.ImageOnly.INSTANCE)
467474
.build());
468475
});
476+
mAudioButton = requireViewById(R.id.audioButton);
477+
mAudioButton.setOnClickListener(
478+
view -> {
479+
mAddMediaLayout.setVisibility(View.GONE);
480+
});
469481
mCameraButton = requireViewById(R.id.cameraButton);
470482
mCameraButton.setOnClickListener(
471483
view -> {
@@ -661,7 +673,8 @@ private void showMediaPreview(List<Uri> uris) {
661673

662674
// For Llava and Gemma 3, we want to call prefill_image as soon as an image is selected
663675
// Llava only supports 1 image for now
664-
if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5 || mCurrentSettingsFields.getModelType() == ModelType.GEMMA_3) {
676+
if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5
677+
|| mCurrentSettingsFields.getModelType() == ModelType.GEMMA_3) {
665678
List<ETImage> processedImageList = getProcessedImagesForModel(mSelectedImageUri);
666679
if (!processedImageList.isEmpty()) {
667680
mMessageAdapter.add(
@@ -673,12 +686,19 @@ private void showMediaPreview(List<Uri> uris) {
673686
ETLogging.getInstance().log("Starting runnable prefill image");
674687
ETImage img = processedImageList.get(0);
675688
ETLogging.getInstance().log("Llava start prefill image");
676-
startPos =
677-
mModule.prefillImages(
678-
img.getInts(),
679-
img.getWidth(),
680-
img.getHeight(),
681-
ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
689+
if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) {
690+
mModule.prefillImages(
691+
img.getInts(),
692+
img.getWidth(),
693+
img.getHeight(),
694+
ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
695+
} else if (mCurrentSettingsFields.getModelType() == ModelType.GEMMA_3) {
696+
mModule.prefillImages(
697+
img.getFloats(),
698+
img.getWidth(),
699+
img.getHeight(),
700+
ModelUtils.VISION_MODEL_IMAGE_CHANNELS);
701+
}
682702
};
683703
executor.execute(runnable);
684704
}
@@ -722,7 +742,6 @@ private void onModelRunStopped() {
722742
String rawPrompt = mEditTextMessage.getText().toString();
723743
String finalPrompt =
724744
mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
725-
mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt, mThinkMode);
726745
// We store raw prompt into message adapter, because we don't want to show the extra
727746
// tokens from system prompt
728747
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID));
@@ -769,10 +788,7 @@ public void run() {
769788
} else {
770789
ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt);
771790
mModule.generate(
772-
finalPrompt,
773-
(int) (finalPrompt.length() * 0.75) + 64,
774-
MainActivity.this,
775-
false);
791+
finalPrompt, ModelUtils.TEXT_MODEL_SEQ_LEN, MainActivity.this, false);
776792
}
777793

778794
long generateDuration = System.currentTimeMillis() - generateStartTime;

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ public class ModelUtils {
1515
// XNNPACK or Vulkan
1616
static final int VISION_MODEL = 2;
1717
static final int VISION_MODEL_IMAGE_CHANNELS = 3;
18-
static final int VISION_MODEL_SEQ_LEN = 768;
19-
static final int TEXT_MODEL_SEQ_LEN = 256;
18+
static final int VISION_MODEL_SEQ_LEN = 2048;
19+
static final int TEXT_MODEL_SEQ_LEN = 768;
2020

2121
// MediaTek
2222
static final int MEDIATEK_TEXT_MODEL = 3;
@@ -29,6 +29,7 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
2929
switch (modelType) {
3030
case GEMMA_3:
3131
case LLAVA_1_5:
32+
case VOXTRAL:
3233
return VISION_MODEL;
3334
case LLAMA_3:
3435
case QWEN_3:
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<vector xmlns:android="http://schemas.android.com/apk/res/android" android:height="48dp" android:tint="#000000" android:viewportHeight="24" android:viewportWidth="24" android:width="48dp">
2+
3+
<path android:fillColor="@android:color/white" android:pathData="M14,2H6C4.9,2 4.01,2.9 4.01,4L4,20c0,1.1 0.89,2 1.99,2H18c1.1,0 2,-0.9 2,-2V8L14,2zM16,13h-3v3.75c0,1.24 -1.01,2.25 -2.25,2.25S8.5,17.99 8.5,16.75c0,-1.24 1.01,-2.25 2.25,-2.25c0.46,0 0.89,0.14 1.25,0.38V11h4V13zM13,9V3.5L18.5,9H13z"/>
4+
5+
</vector>

llm/android/LlamaDemo/app/src/main/res/layout/activity_main.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,15 @@
234234
android:layout_marginStart="40dp"
235235
android:background="@drawable/custom_button_round"
236236
android:src="@drawable/outline_image_48" />
237+
238+
<ImageButton
239+
android:id="@+id/audioButton"
240+
android:layout_width="80dp"
241+
android:layout_height="80dp"
242+
android:layout_marginStart="40dp"
243+
android:background="@drawable/custom_button_round"
244+
android:src="@drawable/baseline_audio_file_48" />
245+
237246
</LinearLayout>
238247
</LinearLayout>
239248

0 commit comments

Comments
 (0)