File tree Expand file tree Collapse file tree 13 files changed +97
-27
lines changed
examples/demo-apps/android/LlamaDemo/app
androidTest/java/com/example/executorchllamademo
main/java/com/example/executorchllamademo
androidTest/java/org/pytorch/executorch
main/java/org/pytorch/executorch/extension/llm
benchmark/android/benchmark/app
src/main/java/org/pytorch/minibench Expand file tree Collapse file tree 13 files changed +97
-27
lines changed Original file line number Diff line number Diff line change @@ -60,6 +60,7 @@ dependencies {
6060 implementation(files(" libs/executorch.aar" ))
6161 implementation(" com.google.android.material:material:1.12.0" )
6262 implementation(" androidx.activity:activity:1.9.0" )
63+ implementation(" org.json:json:20250107" )
6364 testImplementation(" junit:junit:4.13.2" )
6465 androidTestImplementation(" androidx.test.ext:junit:1.1.5" )
6566 androidTestImplementation(" androidx.test.espresso:espresso-core:3.5.1" )
Original file line number Diff line number Diff line change 1818import java .util .ArrayList ;
1919import java .util .Arrays ;
2020import java .util .List ;
21+ import org .json .JSONException ;
22+ import org .json .JSONObject ;
2123import org .junit .Test ;
2224import org .junit .runner .RunWith ;
2325import org .pytorch .executorch .extension .llm .LlmCallback ;
@@ -64,8 +66,16 @@ public void onResult(String result) {
6466 }
6567
6668 @ Override
67- public void onStats (float tps ) {
68- tokensPerSecond .add (tps );
69+ public void onStats (String result ) {
70+ try {
71+ JSONObject jsonObject = new JSONObject (result );
72+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
73+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
74+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
75+ float tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
76+ tokensPerSecond .add (tps );
77+ } catch (JSONException e ) {
78+ }
6979 }
7080
7181 private void report (final String metric , final Float value ) {
Original file line number Diff line number Diff line change 4949import java .util .List ;
5050import java .util .concurrent .Executor ;
5151import java .util .concurrent .Executors ;
52+ import org .json .JSONException ;
53+ import org .json .JSONObject ;
5254import org .pytorch .executorch .extension .llm .LlmCallback ;
5355import org .pytorch .executorch .extension .llm .LlmModule ;
5456
@@ -97,10 +99,20 @@ public void onResult(String result) {
9799 }
98100
99101 @ Override
100- public void onStats (float tps ) {
102+ public void onStats (String stats ) {
101103 runOnUiThread (
102104 () -> {
103105 if (mResultMessage != null ) {
106+ float tps = 0 ;
107+ try {
108+ JSONObject jsonObject = new JSONObject (stats );
109+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
110+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
111+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
112+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
113+ } catch (JSONException e ) {
114+ Log .e ("LLM" , "Error parsing JSON: " + e .getMessage ());
115+ }
104116 mResultMessage .setTokensPerSecond (tps );
105117 mMessageAdapter .notifyDataSetChanged ();
106118 }
Original file line number Diff line number Diff line change 1313import android .os .Looper ;
1414import android .os .Message ;
1515import androidx .annotation .NonNull ;
16+ import org .json .JSONException ;
17+ import org .json .JSONObject ;
1618import org .pytorch .executorch .extension .llm .LlmCallback ;
1719import org .pytorch .executorch .extension .llm .LlmModule ;
1820
@@ -69,7 +71,16 @@ public void onResult(String result) {
6971 }
7072
7173 @ Override
72- public void onStats (float tps ) {
74+ public void onStats (String stats ) {
75+ float tps = 0 ;
76+ try {
77+ JSONObject jsonObject = new JSONObject (stats );
78+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
79+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
80+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
81+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
82+ } catch (JSONException e ) {
83+ }
7384 mCallback .onStats ("tokens/second: " + tps );
7485 }
7586}
Original file line number Diff line number Diff line change @@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
1818
1919 void onTokenGenerated (String token );
2020
21- void onStats (String token );
21+ void onStats (String stats );
2222
2323 void onGenerationStopped ();
2424}
Original file line number Diff line number Diff line change @@ -47,6 +47,7 @@ dependencies {
4747 androidTestImplementation ' androidx.test.ext:junit:1.1.5'
4848 androidTestImplementation ' androidx.test:rules:1.2.0'
4949 androidTestImplementation ' commons-io:commons-io:2.4'
50+ androidTestImplementation ' org.json:json:20250107'
5051}
5152
5253import com.vanniktech.maven.publish.SonatypeHost
Original file line number Diff line number Diff line change 3434import org .apache .commons .io .FileUtils ;
3535import androidx .test .ext .junit .runners .AndroidJUnit4 ;
3636import androidx .test .InstrumentationRegistry ;
37+ import org .json .JSONException ;
38+ import org .json .JSONObject ;
3739import org .pytorch .executorch .extension .llm .LlmCallback ;
3840import org .pytorch .executorch .extension .llm .LlmModule ;
3941
@@ -94,8 +96,17 @@ public void onResult(String result) {
9496 }
9597
9698 @ Override
97- public void onStats (float tps ) {
98- LlmModuleInstrumentationTest .this .onStats (tps );
99+ public void onStats (String stats ) {
100+ float tps = 0 ;
101+ try {
102+ JSONObject jsonObject = new JSONObject (stats );
103+ int numGeneratedTokens = jsonObject .getInt ("generated_tokens" );
104+ int inferenceEndMs = jsonObject .getInt ("inference_end_ms" );
105+ int promptEvalEndMs = jsonObject .getInt ("prompt_eval_end_ms" );
106+ tps = (float ) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs ) * 1000 ;
107+ LlmModuleInstrumentationTest .this .onStats (tps );
108+ } catch (JSONException e ) {
109+ }
99110 }
100111 });
101112
Original file line number Diff line number Diff line change @@ -31,8 +31,22 @@ public interface LlmCallback {
3131 /**
3232 * Called when the statistics for the generate() is available.
3333 *
34+ * Note: This is a deprecated API and will be removed in the future. Please use onStats(String stats)
35+ *
3436 * @param tps Tokens/second for generated tokens.
3537 */
38+ @ Deprecated
39+ @ DoNotStrip
40+ default public void onStats (float tps ) {}
41+
42+ /**
43+ * Called when the statistics for the generate() is available.
44+ *
45+ * The result will be a JSON string. See extension/llm/stats.h for the field
46+ * definitions.
47+ *
48+ * @param stats JSON string containing the statistics for the generate()
49+ */
3650 @ DoNotStrip
37- public void onStats (float tps );
51+ default public void onStats (String stats ) {}
3852}
Original file line number Diff line number Diff line change @@ -100,14 +100,20 @@ class ExecuTorchLlmCallbackJni
100100
101101 void onStats (const llm::Stats& result) const {
102102 static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic ();
103- static const auto method = cls->getMethod <void (jfloat)>(" onStats" );
103+ static const auto tps_method = cls->getMethod <void (jfloat)>(" onStats" );
104104 double eval_time =
105105 (double )(result.inference_end_ms - result.prompt_eval_end_ms );
106106
107107 float tps = result.num_generated_tokens / eval_time *
108108 result.SCALING_FACTOR_UNITS_PER_SECOND ;
109-
110- method (self (), tps);
109+ tps_method (self (), tps);
110+
111+ static const auto on_stats_method =
112+ cls->getMethod <void (facebook::jni::local_ref<jstring>)>(" onStats" );
113+ on_stats_method (
114+ self (),
115+ facebook::jni::make_jstring (
116+ executorch::extension::llm::stats_to_json_string (result)));
111117 }
112118};
113119
Original file line number Diff line number Diff line change @@ -39,6 +39,7 @@ dependencies {
3939 implementation(" com.facebook.soloader:soloader:0.10.5" )
4040 implementation(" com.facebook.fbjni:fbjni:0.5.1" )
4141 implementation(" com.google.code.gson:gson:2.8.6" )
42+ implementation(" org.json:json:20250107" )
4243 testImplementation(" junit:junit:4.13.2" )
4344 androidTestImplementation(" androidx.test.ext:junit:1.2.1" )
4445 androidTestImplementation(" androidx.test.espresso:espresso-core:3.6.1" )
You can’t perform that action at this time.
0 commit comments