diff --git a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts index ea9d4e6c172..893b1ee4784 100644 --- a/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts +++ b/examples/demo-apps/android/LlamaDemo/app/build.gradle.kts @@ -60,6 +60,7 @@ dependencies { implementation(files("libs/executorch.aar")) implementation("com.google.android.material:material:1.12.0") implementation("androidx.activity:activity:1.9.0") + implementation("org.json:json:20250107") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.1.5") androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1") diff --git a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java index 21ac285d3b0..32ec24a0df9 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java @@ -18,6 +18,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.json.JSONException; +import org.json.JSONObject; import org.junit.Test; import org.junit.runner.RunWith; import org.pytorch.executorch.extension.llm.LlmCallback; @@ -64,8 +66,16 @@ public void onResult(String result) { } @Override - public void onStats(float tps) { - tokensPerSecond.add(tps); + public void onStats(String result) { + try { + JSONObject jsonObject = new JSONObject(result); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + 
tokensPerSecond.add(tps); + } catch (JSONException e) { + } } private void report(final String metric, final Float value) { diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index e19155b83e8..137e01f8f43 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -49,6 +49,8 @@ import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import org.json.JSONException; +import org.json.JSONObject; import org.pytorch.executorch.extension.llm.LlmCallback; import org.pytorch.executorch.extension.llm.LlmModule; @@ -97,10 +99,20 @@ public void onResult(String result) { } @Override - public void onStats(float tps) { + public void onStats(String stats) { runOnUiThread( () -> { if (mResultMessage != null) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + } catch (JSONException e) { + Log.e("LLM", "Error parsing JSON: " + e.getMessage()); + } mResultMessage.setTokensPerSecond(tps); mMessageAdapter.notifyDataSetChanged(); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java index 78cfee993c4..a1bc205c4ac 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java +++ 
b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java @@ -13,6 +13,8 @@ import android.os.Looper; import android.os.Message; import androidx.annotation.NonNull; +import org.json.JSONException; +import org.json.JSONObject; import org.pytorch.executorch.extension.llm.LlmCallback; import org.pytorch.executorch.extension.llm.LlmModule; @@ -69,7 +71,16 @@ public void onResult(String result) { } @Override - public void onStats(float tps) { + public void onStats(String stats) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + } catch (JSONException e) { + } mCallback.onStats("tokens/second: " + tps); } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java index c8bdc53075e..5e8b6f00e3d 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java @@ -18,7 +18,7 @@ public interface ModelRunnerCallback { void onTokenGenerated(String token); - void onStats(String token); + void onStats(String stats); void onGenerationStopped(); } diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index b1bc090759a..15088f4097f 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -47,6 +47,7 @@ dependencies { androidTestImplementation 
'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test:rules:1.2.0' androidTestImplementation 'commons-io:commons-io:2.4' + androidTestImplementation 'org.json:json:20250107' } import com.vanniktech.maven.publish.SonatypeHost diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java index e95b42f2650..3791938c3bc 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java @@ -34,6 +34,8 @@ import org.apache.commons.io.FileUtils; import androidx.test.ext.junit.runners.AndroidJUnit4; import androidx.test.InstrumentationRegistry; +import org.json.JSONException; +import org.json.JSONObject; import org.pytorch.executorch.extension.llm.LlmCallback; import org.pytorch.executorch.extension.llm.LlmModule; @@ -94,8 +96,17 @@ public void onResult(String result) { } @Override - public void onStats(float tps) { - LlmModuleInstrumentationTest.this.onStats(tps); + public void onStats(String stats) { + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + LlmModuleInstrumentationTest.this.onStats(tps); + } catch (JSONException e) { + } } }); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java index c05b30b0625..a829cf75dc4 100644 --- 
a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java @@ -31,8 +31,22 @@ public interface LlmCallback { /** * Called when the statistics for the generate() is available. * + * Note: This is a deprecated API and will be removed in the future. Please use onStats(String stats) + * * @param tps Tokens/second for generated tokens. */ + @Deprecated + @DoNotStrip + default public void onStats(float tps) {} + + /** + * Called when the statistics for the generate() is available. + * + * The result will be a JSON string. See extension/llm/stats.h for the field + * definitions. + * + * @param stats JSON string containing the statistics for the generate() + */ @DoNotStrip - public void onStats(float tps); + default public void onStats(String stats) {} } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index edb67bdabea..4a98ccb7c82 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -100,14 +100,20 @@ class ExecuTorchLlmCallbackJni void onStats(const llm::Stats& result) const { static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic(); - static const auto method = cls->getMethod<void(jfloat)>("onStats"); + static const auto tps_method = cls->getMethod<void(jfloat)>("onStats"); double eval_time = (double)(result.inference_end_ms - result.prompt_eval_end_ms); float tps = result.num_generated_tokens / eval_time * result.SCALING_FACTOR_UNITS_PER_SECOND; - - method(self(), tps); + tps_method(self(), tps); + + static const auto on_stats_method = + cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onStats"); + on_stats_method( + self(), + facebook::jni::make_jstring( + executorch::extension::llm::stats_to_json_string(result))); } }; diff --git a/extension/benchmark/android/benchmark/app/build.gradle.kts b/extension/benchmark/android/benchmark/app/build.gradle.kts index 
dcf99ca9cd0..28dfc8ae49d 100644 --- a/extension/benchmark/android/benchmark/app/build.gradle.kts +++ b/extension/benchmark/android/benchmark/app/build.gradle.kts @@ -39,6 +39,7 @@ dependencies { implementation("com.facebook.soloader:soloader:0.10.5") implementation("com.facebook.fbjni:fbjni:0.5.1") implementation("com.google.code.gson:gson:2.8.6") + implementation("org.json:json:20250107") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java index 3bc38aad403..8db2e8633ad 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -23,6 +23,8 @@ import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.json.JSONException; +import org.json.JSONObject; public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { ModelRunner mModelRunner; @@ -80,7 +82,17 @@ public void onTokenGenerated(String token) {} @Override public void onStats(String stats) { - mStatsInfo.tokens = stats; + float tps = 0; + try { + JSONObject jsonObject = new JSONObject(stats); + int numGeneratedTokens = jsonObject.getInt("generated_tokens"); + int inferenceEndMs = jsonObject.getInt("inference_end_ms"); + int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms"); + tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000; + mStatsInfo.tps = tps; + } catch (JSONException e) { + Log.e("LLM", "Error parsing JSON: " + e.getMessage()); + } } @Override @@ -109,7 +121,7 @@ public void onGenerationStopped() { 0.0f)); 
// Token per second results.add( - new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f)); + new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f)); try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { Gson gson = new Gson(); @@ -118,15 +130,6 @@ public void onGenerationStopped() { e.printStackTrace(); } } - - private double extractTPS(final String tokens) { - final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); - if (m.find()) { - return Double.parseDouble(m.group()); - } else { - return 0.0f; - } - } } class StatsInfo { @@ -135,7 +138,7 @@ class StatsInfo { long loadEnd; long generateStart; long generateEnd; - String tokens; + float tps; String modelName; @Override @@ -149,6 +152,6 @@ public String toString() { + "\ngenerateEnd: " + generateEnd + "\n" - + tokens; + + tps; } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java index 6ba1f57c4f3..0a75b47f3a6 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java +++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -68,8 +68,8 @@ public void onResult(String result) { } @Override - public void onStats(float tps) { - mCallback.onStats("tokens/second: " + tps); + public void onStats(String result) { + mCallback.onStats(result); } } diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java index 63701a7bbc6..8503d47ccce 100644 --- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java +++ 
b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java @@ -18,7 +18,7 @@ public interface ModelRunnerCallback { void onTokenGenerated(String token); - void onStats(String token); + void onStats(String result); void onGenerationStopped(); }