
Commit 3b7b379

Update ONNX Runtime GenAI library and refactor MainActivity to use SimpleGenAI API
1 parent b64cc6b commit 3b7b379

File tree

8 files changed: +252 -208 lines changed

mobile/examples/phi-3/android/app/build.gradle.kts

Lines changed: 1 addition & 1 deletion

@@ -52,6 +52,6 @@ dependencies {
 
     // ONNX Runtime with GenAI
     implementation("com.microsoft.onnxruntime:onnxruntime-android:latest.release")
-    implementation(files("libs/onnxruntime-genai-android-0.4.0-dev.aar"))
+    implementation(files("libs/onnxruntime-genai-android-0.8.1.aar"))
 
 }
Binary file not shown.
Binary file not shown.

mobile/examples/phi-3/android/app/src/main/java/ai/onnxruntime/genai/demo/MainActivity.java

Lines changed: 51 additions & 80 deletions

@@ -22,23 +22,16 @@
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
 
-import ai.onnxruntime.genai.GenAIException;
-import ai.onnxruntime.genai.Generator;
-import ai.onnxruntime.genai.GeneratorParams;
-import ai.onnxruntime.genai.Sequences;
-import ai.onnxruntime.genai.TokenizerStream;
-import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
-import ai.onnxruntime.genai.Model;
-import ai.onnxruntime.genai.Tokenizer;
+import ai.onnxruntime.genai.*;
 
 public class MainActivity extends AppCompatActivity implements Consumer<String> {
 
-    private ActivityMainBinding binding;
     private EditText userMsgEdt;
-    private Model model;
-    private Tokenizer tokenizer;
+    private SimpleGenAI genAI;
     private ImageButton sendMsgIB;
     private TextView generatedTV;
     private TextView promptTV;
@@ -56,9 +49,7 @@ private static boolean fileExists(Context context, String fileName) {
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
-
-        binding = ActivityMainBinding.inflate(getLayoutInflater());
-        setContentView(binding.getRoot());
+        setContentView(R.layout.activity_main);
 
         sendMsgIB = findViewById(R.id.idIBSend);
         userMsgEdt = findViewById(R.id.idEdtMessage);
@@ -90,8 +81,6 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         });
 
 
-        Consumer<String> tokenListener = this;
-
         //enable scrolling and resizing of text boxes
         generatedTV.setMovementMethod(new ScrollingMovementMethod());
         getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE);
@@ -100,7 +89,7 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         sendMsgIB.setOnClickListener(new View.OnClickListener() {
             @Override
             public void onClick(View v) {
-                if (tokenizer == null) {
+                if (genAI == null) {
                     // if user tries to submit prompt while model is still downloading, display a toast message.
                     Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show();
                     return;
@@ -131,77 +120,57 @@ public void onClick(View v) {
                 new Thread(new Runnable() {
                     @Override
                     public void run() {
-                        TokenizerStream stream = null;
-                        GeneratorParams generatorParams = null;
-                        Generator generator = null;
-                        Sequences encodedPrompt = null;
                         try {
-                            stream = tokenizer.createStream();
-
-                            generatorParams = model.createGeneratorParams();
-                            //examples for optional parameters to format AI response
+                            // Create generator parameters
+                            GeneratorParams generatorParams = genAI.createGeneratorParams();
+
+                            // Set optional parameters to format AI response
                             // https://onnxruntime.ai/docs/genai/reference/config.html
-                            generatorParams.setSearchOption("length_penalty", lengthPenalty);
-                            generatorParams.setSearchOption("max_length", maxLength);
-
-                            encodedPrompt = tokenizer.encode(promptQuestion_formatted);
-                            generatorParams.setInput(encodedPrompt);
-
-                            generator = new Generator(model, generatorParams);
-
-                            // try to measure average time taken to generate each token.
+                            generatorParams.setSearchOption("length_penalty", (double)lengthPenalty);
+                            generatorParams.setSearchOption("max_length", (double)maxLength);
                             long startTime = System.currentTimeMillis();
-                            long firstTokenTime = startTime;
-                            long currentTime = startTime;
-                            int numTokens = 0;
-                            while (!generator.isDone()) {
-                                generator.computeLogits();
-                                generator.generateNextToken();
-
-                                int token = generator.getLastTokenInSequence(0);
-
-                                if (numTokens == 0) { //first token
-                                    firstTokenTime = System.currentTimeMillis();
+                            AtomicLong firstTokenTime = new AtomicLong(startTime);
+                            AtomicInteger numTokens = new AtomicInteger(0);
+
+                            // Token listener for streaming tokens
+                            Consumer<String> tokenListener = token -> {
+                                if (numTokens.get() == 0) { // first token
+                                    firstTokenTime.set(System.currentTimeMillis());
                                 }
-
-                                tokenListener.accept(stream.decode(token));
-
-
-                                Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token));
-                                Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime)/ 1000.0 + " seconds");
-                                currentTime = System.currentTimeMillis();
-                                numTokens++;
-                            }
-                            long totalTime = System.currentTimeMillis() - firstTokenTime;
-
-                            float promptProcessingTime = (firstTokenTime - startTime)/ 1000.0f;
-                            float tokensPerSecond = (1000 * (numTokens -1)) / totalTime;
+
+                                // Update UI with new token
+                                MainActivity.this.accept(token);
+
+                                Log.i(TAG, "Generated token: " + token);
+                                numTokens.incrementAndGet();
+                            };
+
+                            String fullResponse = genAI.generate(generatorParams, promptQuestion_formatted, tokenListener);
+
+                            long totalTime = System.currentTimeMillis() - firstTokenTime.get();
+                            float promptProcessingTime = (firstTokenTime.get() - startTime) / 1000.0f;
+                            float tokensPerSecond = numTokens.get() > 1 ? (1000.0f * (numTokens.get() - 1)) / totalTime : 0;
 
                             runOnUiThread(() -> {
-                                sendMsgIB.setEnabled(true);
-                                sendMsgIB.setAlpha(1.0f);
-
-                                // Display the token generation rate in a dialog popup
                                 showTokenPopup(promptProcessingTime, tokensPerSecond);
                             });
 
+                            Log.i(TAG, "Full response: " + fullResponse);
                             Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
                             Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
                         }
                         catch (GenAIException e) {
                             Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
+                            runOnUiThread(() -> {
+                                Toast.makeText(MainActivity.this, "Error generating response: " + e.getMessage(), Toast.LENGTH_SHORT).show();
+                            });
                         }
                         finally {
-                            if (generator != null) generator.close();
-                            if (encodedPrompt != null) encodedPrompt.close();
-                            if (stream != null) stream.close();
-                            if (generatorParams != null) generatorParams.close();
+                            runOnUiThread(() -> {
+                                sendMsgIB.setEnabled(true);
+                                sendMsgIB.setAlpha(1.0f);
+                            });
                         }
-
-                        runOnUiThread(() -> {
-                            sendMsgIB.setEnabled(true);
-                            sendMsgIB.setAlpha(1.0f);
-                        });
                     }
                 }).start();
             }
@@ -210,10 +179,10 @@ public void run() {
 
     @Override
     protected void onDestroy() {
-        tokenizer.close();
-        tokenizer = null;
-        model.close();
-        model = null;
+        if (genAI != null) {
+            genAI.close();
+            genAI = null;
+        }
         super.onDestroy();
     }
 
@@ -244,8 +213,7 @@ private void downloadModels(Context context) throws GenAIException {
         // Display a message using Toast
         Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show();
        Log.d(TAG, "All files already exist. Skipping download.");
-        model = new Model(getFilesDir().getPath());
-        tokenizer = model.createTokenizer();
+        genAI = new SimpleGenAI(getFilesDir().getPath());
         return;
     }
 
@@ -276,15 +244,18 @@ public void onDownloadComplete() {
 
                 // Last download completed, create SimpleGenAI
                 try {
-                    model = new Model(getFilesDir().getPath());
-                    tokenizer = model.createTokenizer();
+                    genAI = new SimpleGenAI(getFilesDir().getPath());
                     runOnUiThread(() -> {
                         Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show();
                         progressText.setVisibility(View.INVISIBLE);
                     });
                 } catch (GenAIException e) {
                     e.printStackTrace();
-                    throw new RuntimeException(e);
+                    Log.e(TAG, "Failed to initialize SimpleGenAI: " + e.getMessage());
+                    runOnUiThread(() -> {
+                        Toast.makeText(context, "Failed to load model: " + e.getMessage(), Toast.LENGTH_LONG).show();
+                        progressText.setText("Failed to load model");
+                    });
                 }
 
             }
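
Taken together, the refactor collapses the manual Model/Tokenizer/Generator loop into a single SimpleGenAI.generate() call with a streaming token listener. The snippet below is a minimal sketch of that flow outside of any Android UI code, assuming the onnxruntime-genai 0.8.1 Java API behaves as exercised in the diff above; the model path and prompt string are placeholders, and the wrapper class exists only for illustration.

import java.util.function.Consumer;

import ai.onnxruntime.genai.GenAIException;
import ai.onnxruntime.genai.GeneratorParams;
import ai.onnxruntime.genai.SimpleGenAI;

public class SimpleGenAISketch {
    public static void main(String[] args) throws GenAIException {
        // Load the model from a local directory (placeholder path).
        SimpleGenAI genAI = new SimpleGenAI("/path/to/model");

        // Optional search options; values are passed as doubles, matching the casts in the diff.
        GeneratorParams params = genAI.createGeneratorParams();
        params.setSearchOption("max_length", 256.0);
        params.setSearchOption("length_penalty", 1.0);

        // Stream tokens as they are produced, then receive the full response.
        Consumer<String> onToken = token -> System.out.print(token);
        String fullResponse = genAI.generate(params, "Hello, what can you do?", onToken);
        System.out.println();
        System.out.println("Full response: " + fullResponse);

        // Release native resources, as MainActivity.onDestroy() now does.
        genAI.close();
    }
}

Because generate() drives the decoding loop internally and returns the complete response, the activity no longer has to manage TokenizerStream, Sequences, or Generator lifetimes; closing the single SimpleGenAI instance in onDestroy() is the only cleanup left.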
Binary file not shown (-15.3 KB).
gradle/wrapper/gradle-wrapper.properties

Lines changed: 3 additions & 2 deletions

@@ -1,6 +1,7 @@
-#Mon Mar 25 10:44:29 AEST 2024
 distributionBase=GRADLE_USER_HOME
 distributionPath=wrapper/dists
-distributionUrl=https\://services.gradle.org/distributions/gradle-8.0-bin.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.9-bin.zip
+networkTimeout=10000
+validateDistributionUrl=true
 zipStoreBase=GRADLE_USER_HOME
 zipStorePath=wrapper/dists
