Skip to content

Commit 4944842

Browse files
committed
Add backend parameter to ModelUtil
1 parent 7722e08 commit 4944842

File tree

2 files changed

+28
-14
lines changed

2 files changed

+28
-14
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
125125
long runStartTime = System.currentTimeMillis();
126126
mModule =
127127
new LlamaModule(
128-
//ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()),
129-
3, //TODO: Modify this based on JNI change for how to select MTK backend
128+
ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType()),
130129
modelPath,
131130
tokenizerPath,
132131
temperature);
@@ -175,6 +174,10 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera
175174
+ modelPath
176175
+ "\nTokenizer path: "
177176
+ tokenizerPath
177+
+ "\nBackend: "
178+
+ mCurrentSettingsFields.getBackendType().toString()
179+
+ "\nModelType: "
180+
+ ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
178181
+ "\nTemperature: "
179182
+ temperature
180183
+ "\nModel loaded time: "
@@ -692,7 +695,7 @@ private void onModelRunStopped() {
692695
addSelectedImagesToChatThread(mSelectedImageUri);
693696
String finalPrompt;
694697
String rawPrompt = mEditTextMessage.getText().toString();
695-
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
698+
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
696699
== ModelUtils.VISION_MODEL) {
697700
finalPrompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt);
698701
} else {
@@ -725,7 +728,7 @@ public void run() {
725728
}
726729
});
727730
long generateStartTime = System.currentTimeMillis();
728-
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType())
731+
if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType(), mCurrentSettingsFields.getBackendType())
729732
== ModelUtils.VISION_MODEL) {
730733
mModule.generateFromPos(
731734
finalPrompt,

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,33 @@
99
package com.example.executorchllamademo;
1010

1111
public class ModelUtils {
12+
// XNNPACK or QNN
1213
static final int TEXT_MODEL = 1;
14+
15+
// XNNPACK
1316
static final int VISION_MODEL = 2;
1417
static final int VISION_MODEL_IMAGE_CHANNELS = 3;
15-
//TODO: Make change here based on JNI change on how to indicate MTK backend
1618
static final int VISION_MODEL_SEQ_LEN = 768;
1719
static final int TEXT_MODEL_SEQ_LEN = 256;
1820

19-
public static int getModelCategory(ModelType modelType) {
20-
switch (modelType) {
21-
case LLAVA_1_5:
22-
return VISION_MODEL;
23-
case LLAMA_3:
24-
case LLAMA_3_1:
25-
case LLAMA_3_2:
26-
default:
27-
return TEXT_MODEL;
21+
// MediaTek
22+
static final int MEDIATEK_TEXT_MODEL = 3;
23+
24+
public static int getModelCategory(ModelType modelType, BackendType backendType) {
25+
if (backendType.equals(BackendType.XNNPACK)) {
26+
switch (modelType) {
27+
case LLAVA_1_5:
28+
return VISION_MODEL;
29+
case LLAMA_3:
30+
case LLAMA_3_1:
31+
case LLAMA_3_2:
32+
default:
33+
return TEXT_MODEL;
34+
}
35+
} else if (backendType.equals(BackendType.MEDIATEK)) {
36+
return MEDIATEK_TEXT_MODEL;
2837
}
38+
39+
return TEXT_MODEL; // default
2940
}
3041
}

0 commit comments

Comments
 (0)