Skip to content

Commit 8c813f9

Browse files
authored
Simplify prompt fields logic and added modelType selector
Differential Revision: D60840351 Pull Request resolved: #4581
1 parent 530d4a1 commit 8c813f9

File tree

6 files changed

+194
-79
lines changed

6 files changed

+194
-79
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import android.provider.MediaStore;
2323
import android.system.ErrnoException;
2424
import android.system.Os;
25-
import android.text.InputType;
2625
import android.util.Log;
2726
import android.view.View;
2827
import android.widget.EditText;
@@ -71,9 +70,6 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa
7170
private SettingsFields mCurrentSettingsFields;
7271
private Handler mMemoryUpdateHandler;
7372
private Runnable memoryUpdater;
74-
// UI Specific to user using INSTRUCT_MODE
75-
private boolean INSTRUCT_MODE = false;
76-
private String INSTRUCT_INSTRUCTION = "In Instruct Mode. Press SEND";
7773

7874
@Override
7975
public void onResult(String result) {
@@ -253,7 +249,6 @@ protected void onResume() {
253249
} else {
254250
askUserToSelectModel();
255251
}
256-
checkForPromptChange(updatedSettingsFields);
257252
checkForClearChatHistory(updatedSettingsFields);
258253
// Update current to point to the latest
259254
mCurrentSettingsFields = new SettingsFields(updatedSettingsFields);
@@ -296,29 +291,6 @@ private void checkForUpdateAndReloadModel(SettingsFields updatedSettingsFields)
296291
}
297292
}
298293

299-
private void checkForPromptChange(SettingsFields updatedSettingsFields) {
300-
if (updatedSettingsFields.isSystemPromptChanged()
301-
|| updatedSettingsFields.isUserPromptChanged()) {
302-
enableInstructMode();
303-
} else {
304-
disableInstructMode();
305-
}
306-
}
307-
308-
private void enableInstructMode() {
309-
INSTRUCT_MODE = true;
310-
mEditTextMessage.setText(INSTRUCT_INSTRUCTION);
311-
mEditTextMessage.setInputType(InputType.TYPE_NULL);
312-
mEditTextMessage.clearFocus();
313-
}
314-
315-
private void disableInstructMode() {
316-
INSTRUCT_MODE = false;
317-
mEditTextMessage.setText("");
318-
mEditTextMessage.setInputType(InputType.TYPE_CLASS_TEXT);
319-
mEditTextMessage.clearFocus();
320-
}
321-
322294
private void askUserToSelectModel() {
323295
String askLoadModel =
324296
"To get started, select your desired model and tokenizer " + "from the top right corner";
@@ -600,15 +572,11 @@ private void onModelRunStopped() {
600572
+ " bytes size = "
601573
+ image.getBytes().length);
602574
});
603-
String prompt;
604-
if (INSTRUCT_MODE) {
605-
prompt = mCurrentSettingsFields.getEntirePrompt();
606-
mEditTextMessage.setText(INSTRUCT_INSTRUCTION);
607-
} else {
608-
prompt = mEditTextMessage.getText().toString();
609-
mEditTextMessage.setText("");
610-
}
611-
mMessageAdapter.add(new Message(prompt, true, MessageType.TEXT, 0));
575+
String rawPrompt = mEditTextMessage.getText().toString();
576+
String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
577+
// We store raw prompt into message adapter, because we don't want to show the extra
578+
// tokens from system prompt
579+
mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0));
612580
mMessageAdapter.notifyDataSetChanged();
613581
mEditTextMessage.setText("");
614582
mResultMessage = new Message("", false, MessageType.TEXT, 0);
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
public enum ModelType {
12+
LLAMA_3,
13+
LLAMA_3_1,
14+
LLAVA_1_5,
15+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
public class PromptFormat {
12+
13+
public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}";
14+
public static final String USER_PLACEHOLDER = "{{ user_prompt }}";
15+
16+
public static String getSystemPromptTemplate(ModelType modelType) {
17+
switch (modelType) {
18+
case LLAMA_3:
19+
case LLAMA_3_1:
20+
return "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
21+
+ SYSTEM_PLACEHOLDER
22+
+ "<|eot_id|>";
23+
case LLAVA_1_5:
24+
default:
25+
return SYSTEM_PLACEHOLDER;
26+
}
27+
}
28+
29+
public static String getUserPromptTemplate(ModelType modelType) {
30+
switch (modelType) {
31+
case LLAMA_3:
32+
case LLAMA_3_1:
33+
return "<|start_header_id|>user<|end_header_id|>\n"
34+
+ USER_PLACEHOLDER
35+
+ "<|eot_id|>\n"
36+
+ "<|start_header_id|>assistant<|end_header_id|>";
37+
case LLAVA_1_5:
38+
default:
39+
return USER_PLACEHOLDER;
40+
}
41+
}
42+
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,23 @@
2323
import androidx.core.view.WindowInsetsCompat;
2424
import com.google.gson.Gson;
2525
import java.io.File;
26+
import java.util.ArrayList;
27+
import java.util.List;
2628

2729
public class SettingsActivity extends AppCompatActivity {
2830

2931
private String mModelFilePath = "";
3032
private String mTokenizerFilePath = "";
3133
private TextView mModelTextView;
3234
private TextView mTokenizerTextView;
33-
private ImageButton mModelImageButton;
34-
private ImageButton mTokenizerImageButton;
35+
private TextView mModelTypeTextView;
3536
private EditText mSystemPromptEditText;
3637
private EditText mUserPromptEditText;
3738
private Button mLoadModelButton;
3839
private double mSetTemperature;
3940
private String mSystemPrompt;
4041
private String mUserPrompt;
41-
42+
private ModelType mModelType;
4243
public SettingsFields mSettingsFields;
4344

4445
private DemoSharedPreferences mDemoSharedPreferences;
@@ -63,21 +64,27 @@ protected void onCreate(Bundle savedInstanceState) {
6364
private void setupSettings() {
6465
mModelTextView = requireViewById(R.id.modelTextView);
6566
mTokenizerTextView = requireViewById(R.id.tokenizerTextView);
66-
mModelImageButton = requireViewById(R.id.modelImageButton);
67-
mTokenizerImageButton = requireViewById(R.id.tokenizerImageButton);
67+
mModelTypeTextView = requireViewById(R.id.modelTypeTextView);
68+
ImageButton modelImageButton = requireViewById(R.id.modelImageButton);
69+
ImageButton tokenizerImageButton = requireViewById(R.id.tokenizerImageButton);
70+
ImageButton modelTypeImageButton = requireViewById(R.id.modelTypeImageButton);
6871
mSystemPromptEditText = requireViewById(R.id.systemPromptText);
6972
mUserPromptEditText = requireViewById(R.id.userPromptText);
7073
loadSettings();
7174

7275
// TODO: The two setOnClickListeners will be removed after file path issue is resolved
73-
mModelImageButton.setOnClickListener(
76+
modelImageButton.setOnClickListener(
7477
view -> {
7578
setupModelSelectorDialog();
7679
});
77-
mTokenizerImageButton.setOnClickListener(
80+
tokenizerImageButton.setOnClickListener(
7881
view -> {
7982
setupTokenizerSelectorDialog();
8083
});
84+
modelTypeImageButton.setOnClickListener(
85+
view -> {
86+
setupModelTypeSelectorDialog();
87+
});
8188
mModelFilePath = mSettingsFields.getModelFilePath();
8289
if (!mModelFilePath.isEmpty()) {
8390
mModelTextView.setText(getFilenameFromPath(mModelFilePath));
@@ -86,6 +93,11 @@ private void setupSettings() {
8693
if (!mTokenizerFilePath.isEmpty()) {
8794
mTokenizerTextView.setText(getFilenameFromPath(mTokenizerFilePath));
8895
}
96+
mModelType = mSettingsFields.getModelType();
97+
ETLogging.getInstance().log("mModelType from settings " + mModelType);
98+
if (mModelType != null) {
99+
mModelTypeTextView.setText(mModelType.toString());
100+
}
89101

90102
setupParameterSettings();
91103
setupPromptSettings();
@@ -196,7 +208,8 @@ public void afterTextChanged(Editable s) {
196208
new DialogInterface.OnClickListener() {
197209
public void onClick(DialogInterface dialog, int whichButton) {
198210
// Clear the messageAdapter and sharedPreference
199-
mSystemPromptEditText.setText(mSettingsFields.getSystemPromptTemplate());
211+
mSystemPromptEditText.setText(
212+
PromptFormat.getSystemPromptTemplate(mModelType));
200213
}
201214
})
202215
.setNegativeButton(android.R.string.no, null)
@@ -217,7 +230,11 @@ public void onTextChanged(CharSequence s, int start, int before, int count) {}
217230

218231
@Override
219232
public void afterTextChanged(Editable s) {
220-
mUserPrompt = s.toString();
233+
if (isValidUserPrompt(s.toString())) {
234+
mUserPrompt = s.toString();
235+
} else {
236+
showInvalidPromptDialog();
237+
}
221238
}
222239
});
223240

@@ -233,14 +250,35 @@ public void afterTextChanged(Editable s) {
233250
new DialogInterface.OnClickListener() {
234251
public void onClick(DialogInterface dialog, int whichButton) {
235252
// Clear the messageAdapter and sharedPreference
236-
mUserPromptEditText.setText(mSettingsFields.getUserPromptTemplate());
253+
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
237254
}
238255
})
239256
.setNegativeButton(android.R.string.no, null)
240257
.show();
241258
});
242259
}
243260

261+
private boolean isValidUserPrompt(String userPrompt) {
262+
return userPrompt.contains(PromptFormat.USER_PLACEHOLDER);
263+
}
264+
265+
private void showInvalidPromptDialog() {
266+
new AlertDialog.Builder(this)
267+
.setTitle("Invalid Prompt Format")
268+
.setMessage(
269+
"Prompt format must contain "
270+
+ PromptFormat.USER_PLACEHOLDER
271+
+ ". Do you want to reset prompt format?")
272+
.setIcon(android.R.drawable.ic_dialog_alert)
273+
.setPositiveButton(
274+
android.R.string.yes,
275+
(dialog, whichButton) -> {
276+
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
277+
})
278+
.setNegativeButton(android.R.string.no, null)
279+
.show();
280+
}
281+
244282
private void setupModelSelectorDialog() {
245283
String[] pteFiles = listLocalFile("/data/local/tmp/llama/", ".pte");
246284
AlertDialog.Builder modelPathBuilder = new AlertDialog.Builder(this);
@@ -274,6 +312,29 @@ private static String[] listLocalFile(String path, String suffix) {
274312
return null;
275313
}
276314

315+
private void setupModelTypeSelectorDialog() {
316+
// Convert enum to list
317+
List<String> modelTypesList = new ArrayList<>();
318+
for (ModelType modelType : ModelType.values()) {
319+
modelTypesList.add(modelType.toString());
320+
}
321+
// Alert dialog builder takes in arr of string instead of list
322+
String[] modelTypes = modelTypesList.toArray(new String[0]);
323+
AlertDialog.Builder modelTypeBuilder = new AlertDialog.Builder(this);
324+
modelTypeBuilder.setTitle("Select model type");
325+
modelTypeBuilder.setSingleChoiceItems(
326+
modelTypes,
327+
-1,
328+
(dialog, item) -> {
329+
mModelTypeTextView.setText(modelTypes[item]);
330+
mModelType = ModelType.valueOf(modelTypes[item]);
331+
mUserPromptEditText.setText(PromptFormat.getUserPromptTemplate(mModelType));
332+
dialog.dismiss();
333+
});
334+
335+
modelTypeBuilder.create().show();
336+
}
337+
277338
private void setupTokenizerSelectorDialog() {
278339
String[] binFiles = listLocalFile("/data/local/tmp/llama/", ".bin");
279340
String[] tokenizerFiles = new String[binFiles.length];
@@ -314,6 +375,7 @@ private void saveSettings() {
314375
mSettingsFields.saveTokenizerPath(mTokenizerFilePath);
315376
mSettingsFields.saveParameters(mSetTemperature);
316377
mSettingsFields.savePrompts(mSystemPrompt, mUserPrompt);
378+
mSettingsFields.saveModelType(mModelType);
317379
mDemoSharedPreferences.addSettings(mSettingsFields);
318380
}
319381

0 commit comments

Comments
 (0)