File tree Expand file tree Collapse file tree 6 files changed +24
-5
lines changed
examples/demo-apps/android/LlamaDemo/app
androidTest/java/com/example/executorchllamademo
main/java/com/example/executorchllamademo
extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm Expand file tree Collapse file tree 6 files changed +24
-5
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,7 @@ dependencies {
6060 implementation(files(" libs/executorch.aar" ))
6161 implementation(" com.google.android.material:material:1.12.0" )
6262 implementation(" androidx.activity:activity:1.9.0" )
63+ implementation(" org.json:json:20250107" )
6364 testImplementation(" junit:junit:4.13.2" )
6465 androidTestImplementation(" androidx.test.ext:junit:1.1.5" )
6566 androidTestImplementation(" androidx.test.espresso:espresso-core:3.5.1" )
Original file line number Diff line number Diff line change 1818import java .util .ArrayList ;
1919import java .util .Arrays ;
2020import java .util .List ;
21+ import org .json .JSONObject ;
2122import org .junit .Test ;
2223import org .junit .runner .RunWith ;
2324import org .pytorch .executorch .extension .llm .LlmCallback ;
@@ -64,7 +65,12 @@ public void onResult(String result) {
6465 }
6566
6667 @ Override
67- public void onStats (float tps ) {
68+ public void onStats (String stats ) {
69+ JSONObject jsonObject = new JSONObject (stats );
70+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
71+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
72+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
73+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
6874 tokensPerSecond .add (tps );
6975 }
7076
Original file line number Diff line number Diff line change 4949import java .util .List ;
5050import java .util .concurrent .Executor ;
5151import java .util .concurrent .Executors ;
52+ import org .json .JSONObject ;
5253import org .pytorch .executorch .extension .llm .LlmCallback ;
5354import org .pytorch .executorch .extension .llm .LlmModule ;
5455
@@ -97,10 +98,15 @@ public void onResult(String result) {
9798 }
9899
99100 @ Override
100- public void onStats (float tps ) {
101+ public void onStats (String result ) {
101102 runOnUiThread (
102103 () -> {
103104 if (mResultMessage != null ) {
105+ JSONObject jsonObject = new JSONObject (stats );
106+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
107+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
108+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
109+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
104110 mResultMessage .setTokensPerSecond (tps );
105111 mMessageAdapter .notifyDataSetChanged ();
106112 }
Original file line number Diff line number Diff line change 1313import android .os .Looper ;
1414import android .os .Message ;
1515import androidx .annotation .NonNull ;
16+ import org .json .JSONObject ;
1617import org .pytorch .executorch .extension .llm .LlmCallback ;
1718import org .pytorch .executorch .extension .llm .LlmModule ;
1819
@@ -69,7 +70,12 @@ public void onResult(String result) {
6970 }
7071
7172 @ Override
72- public void onStats (float tps ) {
73+ public void onStats (String stats ) {
74+ JSONObject jsonObject = new JSONObject (stats );
75+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
76+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
77+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
78+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
7379 mCallback .onStats ("tokens/second: " + tps );
7480 }
7581}
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
1818
1919 void onTokenGenerated (String token );
2020
21- void onStats (String token );
21+ void onStats (String stats );
2222
2323 void onGenerationStopped ();
2424}
Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ public interface LlmCallback {
3737 */
3838 @ Deprecated
3939 @ DoNotStrip
40- public void onStats (float tps );
40+ default public void onStats (float tps ) {}
4141
4242 /**
4343 * Called when the statistics for the generate() is available.
You can’t perform that action at this time.
0 commit comments