Commit 7a635da

vraspar and edgchen1 authored

Update Phi3 Android App (#520)
* Update ONNX Runtime GenAI library and refactor MainActivity to use SimpleGenAI API

* Refactor model configuration in MainActivity and update README

* Refactor imports in MainActivity to include specific GenAI classes

* Update mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java

  Co-authored-by: Edward Chen <[email protected]>

* Refactor token handling in MainActivity to use array instead of Atomic types

---------

Co-authored-by: Edward Chen <[email protected]>
1 parent b64cc6b commit 7a635da
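The last commit bullet swaps Atomic types for single-element arrays. Java lambdas may only capture locals that are effectively final, so the refactored token listener keeps its mutable state (token count, first-token time) in `final long[]` slots whose contents can still be changed from inside the lambda. A minimal standalone sketch of the pattern (plain JDK, no GenAI dependency; the token strings are invented):

```java
import java.util.List;
import java.util.function.Consumer;

public class ArrayCaptureDemo {
    public static void main(String[] args) {
        // The array reference is effectively final, so the lambda may capture it;
        // the array contents remain freely mutable from inside the lambda.
        final long[] numTokens = {0};

        Consumer<String> tokenListener = token -> numTokens[0] += 1;

        for (String token : List.of("Hello", ",", " world")) {
            tokenListener.accept(token);
        }
        System.out.println(numTokens[0]); // prints 3
    }
}
```

An `AtomicLong` would also work and adds thread safety, but the listener runs on a single generation thread, so a plain array is sufficient.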

File tree

9 files changed: +303 −226 lines

mobile/examples/phi-3/android/README.md

Lines changed: 33 additions & 7 deletions
````diff
@@ -1,13 +1,38 @@
-# Local Chatbot on Android with Phi-3, ONNX Runtime Mobile and ONNX Runtime Generate() API
+# Local Chatbot on Android with ONNX Runtime Mobile and ONNX Runtime Generate() API
 
 ## Overview
 
-This is a basic [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) Android example application with [ONNX Runtime mobile](https://onnxruntime.ai/docs/tutorials/mobile/) and [ONNX Runtime Generate() API](https://github.com/microsoft/onnxruntime-genai) with support for efficiently running generative AI models. This app demonstrates the usage of phi-3 model in a simple question answering chatbot mode.
+This is a flexible Android chatbot application with [ONNX Runtime mobile](https://onnxruntime.ai/docs/tutorials/mobile/) and [ONNX Runtime Generate() API](https://github.com/microsoft/onnxruntime-genai) that supports efficiently running generative AI models. While it uses [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) by default, **it can work with any ONNX Runtime GenAI compatible model** by simply updating the model configuration in the code.
 
 ### Model
-The model used here is [ONNX Phi-3 model on HuggingFace](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) with INT4 quantization and optimizations for mobile usage.
-
-You can also optimize your fine-tuned PyTorch Phi-3 model for mobile usage following this example [Phi3 optimization with Olive](https://github.com/microsoft/Olive/tree/main/examples/phi3).
+By default, this app uses the [ONNX Phi-3 model on HuggingFace](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) with INT4 quantization and optimizations for mobile usage.
+
+### Using Different Models
+**The app is designed to work with any ONNX Runtime GenAI compatible model.** To use a different model:
+
+1. Open `MainActivity.java` in Android Studio
+2. Locate the model configuration section at the top of the class (marked with comments)
+3. Update the `MODEL_BASE_URL` to point to your model's download location
+4. Update the `MODEL_FILES` list to include all required files for your model
+
+Example for a different model:
+```java
+// Base URL for downloading model files (ensure it ends with '/')
+private static final String MODEL_BASE_URL = "https://your-model-host.com/path/to/model/";
+
+// List of required model files to download
+private static final List<String> MODEL_FILES = Arrays.asList(
+    "config.json",
+    "genai_config.json",
+    "your-model.onnx",
+    "your-model.onnx.data",
+    "tokenizer.json",
+    "tokenizer_config.json"
+    // Add other required files...
+);
+```
+
+**Note:** The model files will be downloaded to `/data/data/ai.onnxruntime.genai.demo/files` on the Android device.
 
 ### Requirements
 - Android Studio Giraffe | 2022.3.1 or later (installed on Mac/Windows/Linux)
````
```diff
@@ -30,7 +55,7 @@ The current set up supports downloading Phi-3-mini model directly from Huggingfa
 You can also follow this link to download **Phi-3-mini**: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4
 and manually copy to the android device file directory following the below instructions:
 
-#### Steps for manual copying models to android device directory:
+#### Steps for manual copying model files to android device directory:
 From Android Studio:
 - create (if necessary) and run your emulator/device
 - make sure it has at least 8GB of internal storage
```
```diff
@@ -40,7 +65,8 @@ From Android Studio:
 - Open Device Explorer in Android Studio
 - Navigate to `/data/data/ai.onnxruntime.genai.demo/files`
   - adjust as needed if the value returned by getFilesDir() differs for your emulator or device
-- copy the whole [phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4) model folder to the `files` directory
+- copy all the required model files (as specified in `MODEL_FILES` in MainActivity.java) directly to the `files` directory
+  - For the default Phi-3 model, copy files from [here](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4)
 
 ### Step 3: Connect Android Device and Run the app
 Connect your Android Device to your computer or select the Android Emulator in Android Studio Device manager.
```
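The manual-copy steps above only succeed if every file named in `MODEL_FILES` ends up in the app's `files` directory. A small standalone sketch of that existence check (plain JDK; the directory and file names are placeholders that mirror, but do not reproduce, the app's `fileExists()` logic):

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class ModelFileCheck {
    // Returns the required files that are not yet present in dir.
    static List<String> missingFiles(Path dir, List<String> required) {
        return required.stream()
                .filter(name -> !Files.exists(dir.resolve(name)))
                .toList();
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("model");
        Files.createFile(dir.resolve("genai_config.json"));

        // "model.onnx" is a hypothetical file name for illustration.
        List<String> required = List.of("genai_config.json", "model.onnx");
        System.out.println(missingFiles(dir, required)); // prints [model.onnx]
    }
}
```

The app applies the same idea per file to decide what still needs downloading before it constructs `SimpleGenAI`.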

mobile/examples/phi-3/android/app/build.gradle.kts

Lines changed: 1 addition & 1 deletion
```diff
@@ -52,6 +52,6 @@ dependencies {
 
     // ONNX Runtime with GenAI
     implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
-    implementation(files("libs/onnxruntime-genai-android-0.4.0-dev.aar"))
+    implementation(files("libs/onnxruntime-genai-android-0.8.1.aar"))
 
 }
```
Binary file not shown.
Binary file not shown.

mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java

Lines changed: 69 additions & 91 deletions
```diff
@@ -24,21 +24,33 @@
 import java.util.concurrent.Executors;
 import java.util.function.Consumer;
 
+import ai.onnxruntime.genai.SimpleGenAI;
 import ai.onnxruntime.genai.GenAIException;
-import ai.onnxruntime.genai.Generator;
 import ai.onnxruntime.genai.GeneratorParams;
-import ai.onnxruntime.genai.Sequences;
-import ai.onnxruntime.genai.TokenizerStream;
-import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
-import ai.onnxruntime.genai.Model;
-import ai.onnxruntime.genai.Tokenizer;
 
 public class MainActivity extends AppCompatActivity implements Consumer<String> {
 
-    private ActivityMainBinding binding;
+    // ===== MODEL CONFIGURATION - MODIFY THESE FOR DIFFERENT MODELS =====
+    // Base URL for downloading model files (ensure it ends with '/')
+    private static final String MODEL_BASE_URL = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
+
+    // List of required model files to download
+    private static final List<String> MODEL_FILES = Arrays.asList(
+            "added_tokens.json",
+            "config.json",
+            "configuration_phi3.py",
+            "genai_config.json",
+            "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
+            "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data",
+            "special_tokens_map.json",
+            "tokenizer.json",
+            "tokenizer.model",
+            "tokenizer_config.json"
+    );
+    // ===== END MODEL CONFIGURATION =====
+
     private EditText userMsgEdt;
-    private Model model;
-    private Tokenizer tokenizer;
+    private SimpleGenAI genAI;
     private ImageButton sendMsgIB;
     private TextView generatedTV;
     private TextView promptTV;
@@ -56,9 +68,7 @@ private static boolean fileExists(Context context, String fileName) {
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
-
-        binding = ActivityMainBinding.inflate(getLayoutInflater());
-        setContentView(binding.getRoot());
+        setContentView(R.layout.activity_main);
 
         sendMsgIB = findViewById(R.id.idIBSend);
         userMsgEdt = findViewById(R.id.idEdtMessage);
@@ -90,8 +100,6 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         });
 
 
-        Consumer<String> tokenListener = this;
-
         //enable scrolling and resizing of text boxes
         generatedTV.setMovementMethod(new ScrollingMovementMethod());
         getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE);
@@ -100,7 +108,7 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         sendMsgIB.setOnClickListener(new View.OnClickListener() {
             @Override
             public void onClick(View v) {
-                if (tokenizer == null) {
+                if (genAI == null) {
                     // if user tries to submit prompt while model is still downloading, display a toast message.
                     Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show();
                     return;
@@ -131,77 +139,58 @@ public void onClick(View v) {
                 new Thread(new Runnable() {
                     @Override
                     public void run() {
-                        TokenizerStream stream = null;
-                        GeneratorParams generatorParams = null;
-                        Generator generator = null;
-                        Sequences encodedPrompt = null;
                         try {
-                            stream = tokenizer.createStream();
-
-                            generatorParams = model.createGeneratorParams();
-                            //examples for optional parameters to format AI response
+                            // Create generator parameters
+                            GeneratorParams generatorParams = genAI.createGeneratorParams();
+
+                            // Set optional parameters to format AI response
                             // https://onnxruntime.ai/docs/genai/reference/config.html
-                            generatorParams.setSearchOption("length_penalty", lengthPenalty);
-                            generatorParams.setSearchOption("max_length", maxLength);
-
-                            encodedPrompt = tokenizer.encode(promptQuestion_formatted);
-                            generatorParams.setInput(encodedPrompt);
-
-                            generator = new Generator(model, generatorParams);
-
-                            // try to measure average time taken to generate each token.
+                            generatorParams.setSearchOption("length_penalty", (double)lengthPenalty);
+                            generatorParams.setSearchOption("max_length", (double)maxLength);
                             long startTime = System.currentTimeMillis();
-                            long firstTokenTime = startTime;
-                            long currentTime = startTime;
-                            int numTokens = 0;
-                            while (!generator.isDone()) {
-                                generator.computeLogits();
-                                generator.generateNextToken();
-
-                                int token = generator.getLastTokenInSequence(0);
-
-                                if (numTokens == 0) { //first token
-                                    firstTokenTime = System.currentTimeMillis();
+                            final long[] firstTokenTime = {startTime};
+                            final long[] numTokens = {0};
+
+                            // Token listener for streaming tokens
+                            Consumer<String> tokenListener = token -> {
+                                if (numTokens[0] == 0) {
+                                    firstTokenTime[0] = System.currentTimeMillis();
                                 }
 
-                                tokenListener.accept(stream.decode(token));
+
+                                // Update UI with new token
+                                MainActivity.this.accept(token);
+
+                                Log.i(TAG, "Generated token: " + token);
+                                numTokens[0] += 1;
+                            };
 
-
-                                Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token));
-                                Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime)/ 1000.0 + " seconds");
-                                currentTime = System.currentTimeMillis();
-                                numTokens++;
-                            }
-                            long totalTime = System.currentTimeMillis() - firstTokenTime;
-
-                            float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
-                            float tokensPerSecond = (1000 * (numTokens -1)) / totalTime;
+                            String fullResponse = genAI.generate(generatorParams, promptQuestion_formatted, tokenListener);
+
+                            long totalTime = System.currentTimeMillis() - firstTokenTime[0];
+                            float promptProcessingTime = (firstTokenTime[0] - startTime) / 1000.0f;
+                            float tokensPerSecond = numTokens[0] > 1 ? (1000.0f * (numTokens[0] - 1)) / totalTime : 0;
 
                             runOnUiThread(() -> {
-                                sendMsgIB.setEnabled(true);
-                                sendMsgIB.setAlpha(1.0f);
-
-                                // Display the token generation rate in a dialog popup
                                 showTokenPopup(promptProcessingTime, tokensPerSecond);
                             });
 
+                            Log.i(TAG, "Full response: " + fullResponse);
                             Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
                             Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
                         }
                         catch (GenAIException e) {
                             Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
+                            runOnUiThread(() -> {
+                                Toast.makeText(MainActivity.this, "Error generating response: " + e.getMessage(), Toast.LENGTH_SHORT).show();
+                            });
                         }
                         finally {
-                            if (generator != null) generator.close();
-                            if (encodedPrompt != null) encodedPrompt.close();
-                            if (stream != null) stream.close();
-                            if (generatorParams != null) generatorParams.close();
+                            runOnUiThread(() -> {
+                                sendMsgIB.setEnabled(true);
+                                sendMsgIB.setAlpha(1.0f);
+                            });
                         }
-
-                        runOnUiThread(() -> {
-                            sendMsgIB.setEnabled(true);
-                            sendMsgIB.setAlpha(1.0f);
-                        });
                     }
                 }).start();
             }
@@ -210,42 +199,28 @@ public void run() {
 
     @Override
     protected void onDestroy() {
-        tokenizer.close();
-        tokenizer = null;
-        model.close();
-        model = null;
+        if (genAI != null) {
+            genAI.close();
+            genAI = null;
+        }
         super.onDestroy();
     }
 
     private void downloadModels(Context context) throws GenAIException {
 
-        final String baseUrl = "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx/resolve/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/";
-        List<String> files = Arrays.asList(
-                "added_tokens.json",
-                "config.json",
-                "configuration_phi3.py",
-                "genai_config.json",
-                "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx",
-                "phi3-mini-4k-instruct-cpu-int4-rtn-block-32-acc-level-4.onnx.data",
-                "special_tokens_map.json",
-                "tokenizer.json",
-                "tokenizer.model",
-                "tokenizer_config.json");
-
         List<Pair<String, String>> urlFilePairs = new ArrayList<>();
-        for (String file : files) {
+        for (String file : MODEL_FILES) {
             if (!fileExists(context, file)) {
                 urlFilePairs.add(new Pair<>(
-                        baseUrl + file,
+                        MODEL_BASE_URL + file,
                         file));
             }
         }
         if (urlFilePairs.isEmpty()) {
             // Display a message using Toast
             Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show();
             Log.d(TAG, "All files already exist. Skipping download.");
-            model = new Model(getFilesDir().getPath());
-            tokenizer = model.createTokenizer();
+            genAI = new SimpleGenAI(getFilesDir().getPath());
             return;
         }
 
@@ -276,15 +251,18 @@ public void onDownloadComplete() {
 
                 // Last download completed, create SimpleGenAI
                 try {
-                    model = new Model(getFilesDir().getPath());
-                    tokenizer = model.createTokenizer();
+                    genAI = new SimpleGenAI(getFilesDir().getPath());
                     runOnUiThread(() -> {
                         Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show();
                         progressText.setVisibility(View.INVISIBLE);
                     });
                 } catch (GenAIException e) {
                     e.printStackTrace();
-                    throw new RuntimeException(e);
+                    Log.e(TAG, "Failed to initialize SimpleGenAI: " + e.getMessage());
+                    runOnUiThread(() -> {
+                        Toast.makeText(context, "Failed to load model: " + e.getMessage(), Toast.LENGTH_LONG).show();
+                        progressText.setText("Failed to load model");
+                    });
                 }
 
             }
```
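The refactored `run()` derives two metrics from three timestamps: time to the first streamed token is reported as prompt processing, and the remaining `numTokens - 1` tokens measure decode throughput (the `numTokens[0] > 1` guard avoids dividing by zero when at most one token was produced). The arithmetic in isolation, with invented timestamps:

```java
public class ThroughputDemo {
    // Same formula as the refactored MainActivity: exclude the first token,
    // whose latency is dominated by prompt processing.
    static float tokensPerSecond(long firstTokenTime, long endTime, long numTokens) {
        long totalTime = endTime - firstTokenTime;
        return numTokens > 1 ? (1000.0f * (numTokens - 1)) / totalTime : 0;
    }

    public static void main(String[] args) {
        long startTime = 0;        // hypothetical timestamps, in milliseconds
        long firstTokenTime = 500; // first token after 0.5 s of prompt processing
        long endTime = 2500;       // 40 more tokens over the next 2 s

        float promptProcessingTime = (firstTokenTime - startTime) / 1000.0f;
        System.out.println(promptProcessingTime);                         // prints 0.5
        System.out.println(tokensPerSecond(firstTokenTime, endTime, 41)); // prints 20.0
    }
}
```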
-15.3 KB
Binary file not shown.
Lines changed: 3 additions & 2 deletions
```diff
@@ -1,6 +1,7 @@
-#Mon Mar 25 10:44:29 AEST 2024
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
+networkTimeout=10000
+validateDistributionUrl=true
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
```
