File tree Expand file tree Collapse file tree 6 files changed +24
-5
lines changed
examples/demo-apps/android/LlamaDemo/app
androidTest/java/com/example/executorchllamademo
main/java/com/example/executorchllamademo
extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm Expand file tree Collapse file tree 6 files changed +24
-5
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,7 @@ dependencies {
60
60
implementation(files(" libs/executorch.aar" ))
61
61
implementation(" com.google.android.material:material:1.12.0" )
62
62
implementation(" androidx.activity:activity:1.9.0" )
63
+ implementation(" org.json:json:20250107" )
63
64
testImplementation(" junit:junit:4.13.2" )
64
65
androidTestImplementation(" androidx.test.ext:junit:1.1.5" )
65
66
androidTestImplementation(" androidx.test.espresso:espresso-core:3.5.1" )
Original file line number Diff line number Diff line change 18
18
import java .util .ArrayList ;
19
19
import java .util .Arrays ;
20
20
import java .util .List ;
21
+ import org .json .JSONObject ;
21
22
import org .junit .Test ;
22
23
import org .junit .runner .RunWith ;
23
24
import org .pytorch .executorch .extension .llm .LlmCallback ;
@@ -64,7 +65,12 @@ public void onResult(String result) {
64
65
}
65
66
66
67
@ Override
67
- public void onStats (float tps ) {
68
+ public void onStats (String stats ) {
69
+ JSONObject jsonObject = new JSONObject (stats );
70
+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
71
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
72
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
73
+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
68
74
tokensPerSecond .add (tps );
69
75
}
70
76
Original file line number Diff line number Diff line change 49
49
import java .util .List ;
50
50
import java .util .concurrent .Executor ;
51
51
import java .util .concurrent .Executors ;
52
+ import org .json .JSONObject ;
52
53
import org .pytorch .executorch .extension .llm .LlmCallback ;
53
54
import org .pytorch .executorch .extension .llm .LlmModule ;
54
55
@@ -97,10 +98,15 @@ public void onResult(String result) {
97
98
}
98
99
99
100
@ Override
100
- public void onStats (float tps ) {
101
+ public void onStats (String result ) {
101
102
runOnUiThread (
102
103
() -> {
103
104
if (mResultMessage != null ) {
105
+ JSONObject jsonObject = new JSONObject (stats );
106
+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
107
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
108
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
109
+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
104
110
mResultMessage .setTokensPerSecond (tps );
105
111
mMessageAdapter .notifyDataSetChanged ();
106
112
}
Original file line number Diff line number Diff line change 13
13
import android .os .Looper ;
14
14
import android .os .Message ;
15
15
import androidx .annotation .NonNull ;
16
+ import org .json .JSONObject ;
16
17
import org .pytorch .executorch .extension .llm .LlmCallback ;
17
18
import org .pytorch .executorch .extension .llm .LlmModule ;
18
19
@@ -69,7 +70,12 @@ public void onResult(String result) {
69
70
}
70
71
71
72
@ Override
72
- public void onStats (float tps ) {
73
+ public void onStats (String stats ) {
74
+ JSONObject jsonObject = new JSONObject (stats );
75
+ int numGeneratedTokens = jsonObject .getInt ("num_generated_tokens" );
76
+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
77
+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
78
+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
73
79
mCallback .onStats ("tokens/second: " + tps );
74
80
}
75
81
}
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
18
18
19
19
void onTokenGenerated (String token );
20
20
21
- void onStats (String token );
21
+ void onStats (String stats );
22
22
23
23
void onGenerationStopped ();
24
24
}
Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ public interface LlmCallback {
37
37
*/
38
38
@ Deprecated
39
39
@ DoNotStrip
40
- public void onStats (float tps );
40
+ default public void onStats (float tps ) {}
41
41
42
42
/**
43
43
* Called when the statistics for the generate() is available.
You can’t perform that action at this time.
0 commit comments