@@ -61,6 +61,7 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
6161 private ImageButton mSendButton ;
6262 private ImageButton mGalleryButton ;
6363 private ImageButton mCameraButton ;
64+ private ImageButton mAudioButton ;
6465 private ListView mMessagesView ;
6566 private MessageAdapter mMessageAdapter ;
6667 private LlmModule mModule = null ;
@@ -81,30 +82,36 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlmCall
8182 private Runnable memoryUpdater ;
8283 private boolean mThinkMode = false ;
8384 private int promptID = 0 ;
84- private long startPos = 0 ;
85- private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2 ;
8685 private Executor executor ;
8786 private boolean sawStartHeaderId = false ;
8887
8988 @ Override
9089 public void onResult (String result ) {
9190 if (result .equals (PromptFormat .getStopToken (mCurrentSettingsFields .getModelType ()))) {
91+ // For gemma and llava, we need to call stop() explicitly
92+ if (mCurrentSettingsFields .getModelType () == ModelType .GEMMA_3
93+ || mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5 ) {
94+ mModule .stop ();
95+ }
9296 return ;
9397 }
9498 result = PromptFormat .replaceSpecialToken (mCurrentSettingsFields .getModelType (), result );
9599
96- if (mCurrentSettingsFields .getModelType () == ModelType .LLAMA_3 && result .equals ("<|start_header_id|>" )) {
100+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAMA_3
101+ && result .equals ("<|start_header_id|>" )) {
97102 sawStartHeaderId = true ;
98103 }
99- if (mCurrentSettingsFields .getModelType () == ModelType .LLAMA_3 && result .equals ("<|end_header_id|>" )) {
104+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAMA_3
105+ && result .equals ("<|end_header_id|>" )) {
100106 sawStartHeaderId = false ;
101107 return ;
102108 }
103109 if (sawStartHeaderId ) {
104110 return ;
105111 }
106112
107- boolean keepResult = !(result .equals ("\n " ) || result .equals ("\n \n " )) || !mResultMessage .getText ().isEmpty ();
113+ boolean keepResult =
114+ !(result .equals ("\n " ) || result .equals ("\n \n " )) || !mResultMessage .getText ().isEmpty ();
108115 if (keepResult ) {
109116 mResultMessage .appendText (result );
110117 run ();
@@ -466,6 +473,11 @@ private void setupMediaButton() {
466473 .setMediaType (ActivityResultContracts .PickVisualMedia .ImageOnly .INSTANCE )
467474 .build ());
468475 });
476+ mAudioButton = requireViewById (R .id .audioButton );
477+ mAudioButton .setOnClickListener (
478+ view -> {
479+ mAddMediaLayout .setVisibility (View .GONE );
480+ });
469481 mCameraButton = requireViewById (R .id .cameraButton );
470482 mCameraButton .setOnClickListener (
471483 view -> {
@@ -661,7 +673,8 @@ private void showMediaPreview(List<Uri> uris) {
661673
662674 // For LLava, we want to call prefill_image as soon as an image is selected
663675 // Llava only support 1 image for now
664- if (mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5 || mCurrentSettingsFields .getModelType () == ModelType .GEMMA_3 ) {
676+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5
677+ || mCurrentSettingsFields .getModelType () == ModelType .GEMMA_3 ) {
665678 List <ETImage > processedImageList = getProcessedImagesForModel (mSelectedImageUri );
666679 if (!processedImageList .isEmpty ()) {
667680 mMessageAdapter .add (
@@ -673,12 +686,19 @@ private void showMediaPreview(List<Uri> uris) {
673686 ETLogging .getInstance ().log ("Starting runnable prefill image" );
674687 ETImage img = processedImageList .get (0 );
675688 ETLogging .getInstance ().log ("Llava start prefill image" );
676- startPos =
677- mModule .prefillImages (
678- img .getInts (),
679- img .getWidth (),
680- img .getHeight (),
681- ModelUtils .VISION_MODEL_IMAGE_CHANNELS );
689+ if (mCurrentSettingsFields .getModelType () == ModelType .LLAVA_1_5 ) {
690+ mModule .prefillImages (
691+ img .getInts (),
692+ img .getWidth (),
693+ img .getHeight (),
694+ ModelUtils .VISION_MODEL_IMAGE_CHANNELS );
695+ } else if (mCurrentSettingsFields .getModelType () == ModelType .GEMMA_3 ) {
696+ mModule .prefillImages (
697+ img .getFloats (),
698+ img .getWidth (),
699+ img .getHeight (),
700+ ModelUtils .VISION_MODEL_IMAGE_CHANNELS );
701+ }
682702 };
683703 executor .execute (runnable );
684704 }
@@ -722,7 +742,6 @@ private void onModelRunStopped() {
722742 String rawPrompt = mEditTextMessage .getText ().toString ();
723743 String finalPrompt =
724744 mCurrentSettingsFields .getFormattedSystemAndUserPrompt (rawPrompt , mThinkMode );
725- mCurrentSettingsFields .getFormattedSystemAndUserPrompt (rawPrompt , mThinkMode );
726745 // We store raw prompt into message adapter, because we don't want to show the extra
727746 // tokens from system prompt
728747 mMessageAdapter .add (new Message (rawPrompt , true , MessageType .TEXT , promptID ));
@@ -769,10 +788,7 @@ public void run() {
769788 } else {
770789 ETLogging .getInstance ().log ("Running inference.. prompt=" + finalPrompt );
771790 mModule .generate (
772- finalPrompt ,
773- (int ) (finalPrompt .length () * 0.75 ) + 64 ,
774- MainActivity .this ,
775- false );
791+ finalPrompt , ModelUtils .TEXT_MODEL_SEQ_LEN , MainActivity .this , false );
776792 }
777793
778794 long generateDuration = System .currentTimeMillis () - generateStartTime ;
0 commit comments