 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.Consumer;
 
-import ai.onnxruntime.genai.GenAIException;
-import ai.onnxruntime.genai.Generator;
-import ai.onnxruntime.genai.GeneratorParams;
-import ai.onnxruntime.genai.Sequences;
-import ai.onnxruntime.genai.TokenizerStream;
-import ai.onnxruntime.genai.demo.databinding.ActivityMainBinding;
-import ai.onnxruntime.genai.Model;
-import ai.onnxruntime.genai.Tokenizer;
+import ai.onnxruntime.genai.*;
 
 public class MainActivity extends AppCompatActivity implements Consumer<String> {
 
-    private ActivityMainBinding binding;
     private EditText userMsgEdt;
-    private Model model;
-    private Tokenizer tokenizer;
+    private SimpleGenAI genAI;
     private ImageButton sendMsgIB;
     private TextView generatedTV;
     private TextView promptTV;
@@ -56,9 +49,7 @@ private static boolean fileExists(Context context, String fileName) {
     @Override
     protected void onCreate(Bundle savedInstanceState) {
         super.onCreate(savedInstanceState);
-
-        binding = ActivityMainBinding.inflate(getLayoutInflater());
-        setContentView(binding.getRoot());
+        setContentView(R.layout.activity_main);
 
         sendMsgIB = findViewById(R.id.idIBSend);
         userMsgEdt = findViewById(R.id.idEdtMessage);
@@ -90,8 +81,6 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         });
 
 
-        Consumer<String> tokenListener = this;
-
         //enable scrolling and resizing of text boxes
         generatedTV.setMovementMethod(new ScrollingMovementMethod());
         getWindow().setSoftInputMode(WindowManager.LayoutParams.SOFT_INPUT_ADJUST_RESIZE);
@@ -100,7 +89,7 @@ public void onSettingsApplied(int maxLength, float lengthPenalty) {
         sendMsgIB.setOnClickListener(new View.OnClickListener() {
             @Override
             public void onClick(View v) {
-                if (tokenizer == null) {
+                if (genAI == null) {
                     // if user tries to submit prompt while model is still downloading, display a toast message.
                     Toast.makeText(MainActivity.this, "Model not loaded yet, please wait...", Toast.LENGTH_SHORT).show();
                     return;
@@ -131,77 +120,57 @@ public void onClick(View v) {
                 new Thread(new Runnable() {
                     @Override
                     public void run() {
-                        TokenizerStream stream = null;
-                        GeneratorParams generatorParams = null;
-                        Generator generator = null;
-                        Sequences encodedPrompt = null;
                         try {
-                            stream = tokenizer.createStream();
-
-                            generatorParams = model.createGeneratorParams();
-                            //examples for optional parameters to format AI response
+                            // Create generator parameters
+                            GeneratorParams generatorParams = genAI.createGeneratorParams();
+
+                            // Set optional parameters to format AI response
                             // https://onnxruntime.ai/docs/genai/reference/config.html
-                            generatorParams.setSearchOption("length_penalty", lengthPenalty);
-                            generatorParams.setSearchOption("max_length", maxLength);
-
-                            encodedPrompt = tokenizer.encode(promptQuestion_formatted);
-                            generatorParams.setInput(encodedPrompt);
-
-                            generator = new Generator(model, generatorParams);
-
-                            // try to measure average time taken to generate each token.
+                            generatorParams.setSearchOption("length_penalty", (double) lengthPenalty);
+                            generatorParams.setSearchOption("max_length", (double) maxLength);
                             long startTime = System.currentTimeMillis();
-                            long firstTokenTime = startTime;
-                            long currentTime = startTime;
-                            int numTokens = 0;
-                            while (!generator.isDone()) {
-                                generator.computeLogits();
-                                generator.generateNextToken();
-
-                                int token = generator.getLastTokenInSequence(0);
-
-                                if (numTokens == 0) { //first token
-                                    firstTokenTime = System.currentTimeMillis();
+                            AtomicLong firstTokenTime = new AtomicLong(startTime);
+                            AtomicInteger numTokens = new AtomicInteger(0);
+
+                            // Token listener for streaming tokens
+                            Consumer<String> tokenListener = token -> {
+                                if (numTokens.get() == 0) { // first token
+                                    firstTokenTime.set(System.currentTimeMillis());
                                 }
-
-                                tokenListener.accept(stream.decode(token));
-
-
-                                Log.i(TAG, "Generated token: " + token + ": " + stream.decode(token));
-                                Log.i(TAG, "Time taken to generate token: " + (System.currentTimeMillis() - currentTime) / 1000.0 + " seconds");
-                                currentTime = System.currentTimeMillis();
-                                numTokens++;
-                            }
-                            long totalTime = System.currentTimeMillis() - firstTokenTime;
-
-                            float promptProcessingTime = (firstTokenTime - startTime) / 1000.0f;
-                            float tokensPerSecond = (1000 * (numTokens - 1)) / totalTime;
+
+                                // Update UI with new token
+                                MainActivity.this.accept(token);
+
+                                Log.i(TAG, "Generated token: " + token);
+                                numTokens.incrementAndGet();
+                            };
+
+                            String fullResponse = genAI.generate(generatorParams, promptQuestion_formatted, tokenListener);
+
+                            long totalTime = System.currentTimeMillis() - firstTokenTime.get();
+                            float promptProcessingTime = (firstTokenTime.get() - startTime) / 1000.0f;
+                            float tokensPerSecond = numTokens.get() > 1 ? (1000.0f * (numTokens.get() - 1)) / totalTime : 0;
 
                             runOnUiThread(() -> {
-                                sendMsgIB.setEnabled(true);
-                                sendMsgIB.setAlpha(1.0f);
-
-                                // Display the token generation rate in a dialog popup
                                 showTokenPopup(promptProcessingTime, tokensPerSecond);
                             });
 
+                            Log.i(TAG, "Full response: " + fullResponse);
                             Log.i(TAG, "Prompt processing time (first token): " + promptProcessingTime + " seconds");
                             Log.i(TAG, "Tokens generated per second (excluding prompt processing): " + tokensPerSecond);
                         }
                         catch (GenAIException e) {
                             Log.e(TAG, "Exception occurred during model query: " + e.getMessage());
+                            runOnUiThread(() -> {
+                                Toast.makeText(MainActivity.this, "Error generating response: " + e.getMessage(), Toast.LENGTH_SHORT).show();
+                            });
                         }
                         finally {
-                            if (generator != null) generator.close();
-                            if (encodedPrompt != null) encodedPrompt.close();
-                            if (stream != null) stream.close();
-                            if (generatorParams != null) generatorParams.close();
+                            runOnUiThread(() -> {
+                                sendMsgIB.setEnabled(true);
+                                sendMsgIB.setAlpha(1.0f);
+                            });
                         }
-
-                        runOnUiThread(() -> {
-                            sendMsgIB.setEnabled(true);
-                            sendMsgIB.setAlpha(1.0f);
-                        });
                     }
                 }).start();
             }
@@ -210,10 +179,10 @@ public void run() {
 
     @Override
     protected void onDestroy() {
-        tokenizer.close();
-        tokenizer = null;
-        model.close();
-        model = null;
+        if (genAI != null) {
+            genAI.close();
+            genAI = null;
+        }
         super.onDestroy();
     }
 
@@ -244,8 +213,7 @@ private void downloadModels(Context context) throws GenAIException {
             // Display a message using Toast
             Toast.makeText(this, "All files already exist. Skipping download.", Toast.LENGTH_SHORT).show();
             Log.d(TAG, "All files already exist. Skipping download.");
-            model = new Model(getFilesDir().getPath());
-            tokenizer = model.createTokenizer();
+            genAI = new SimpleGenAI(getFilesDir().getPath());
            return;
         }
 
@@ -276,15 +244,18 @@ public void onDownloadComplete() {
 
                     // Last download completed, create SimpleGenAI
                     try {
-                        model = new Model(getFilesDir().getPath());
-                        tokenizer = model.createTokenizer();
+                        genAI = new SimpleGenAI(getFilesDir().getPath());
                         runOnUiThread(() -> {
                             Toast.makeText(context, "All downloads completed", Toast.LENGTH_SHORT).show();
                             progressText.setVisibility(View.INVISIBLE);
                         });
                     } catch (GenAIException e) {
                         e.printStackTrace();
-                        throw new RuntimeException(e);
+                        Log.e(TAG, "Failed to initialize SimpleGenAI: " + e.getMessage());
+                        runOnUiThread(() -> {
+                            Toast.makeText(context, "Failed to load model: " + e.getMessage(), Toast.LENGTH_LONG).show();
+                            progressText.setText("Failed to load model");
+                        });
                     }
 
                 }
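For reference, here is a minimal, self-contained sketch of the SimpleGenAI flow this commit switches to, without the Android UI plumbing. It only uses the calls that appear in the diff above (SimpleGenAI, createGeneratorParams, setSearchOption, generate, close); the model directory, prompt string, and search-option values are placeholders, not taken from the app.

import java.util.function.Consumer;

import ai.onnxruntime.genai.GenAIException;
import ai.onnxruntime.genai.GeneratorParams;
import ai.onnxruntime.genai.SimpleGenAI;

public class SimpleGenAISketch {
    public static void main(String[] args) throws GenAIException {
        // Placeholder path: point this at a directory containing an ONNX Runtime GenAI model.
        SimpleGenAI genAI = new SimpleGenAI("/path/to/model");
        try {
            GeneratorParams params = genAI.createGeneratorParams();
            // Illustrative values; see https://onnxruntime.ai/docs/genai/reference/config.html
            params.setSearchOption("max_length", 256.0);
            params.setSearchOption("length_penalty", 1.0);

            // Streaming listener: invoked once per decoded token as it is generated.
            Consumer<String> onToken = token -> System.out.print(token);

            // Placeholder prompt; real prompt formatting depends on the model being used.
            String fullResponse = genAI.generate(params, "Hello, how are you?", onToken);
            System.out.println();
            System.out.println("Full response: " + fullResponse);
        } finally {
            genAI.close();
        }
    }
}

Compared with the removed Generator/TokenizerStream loop, the token-by-token decode, UI update, and cleanup are handled inside generate(), which is why the diff drops the explicit close() calls on generator, stream, and encodedPrompt.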