Skip to content

Commit 40b3f02

Browse files
committed
Update some stuff
1 parent 553257b commit 40b3f02

File tree

6 files changed

+24
-5
lines changed

6 files changed

+24
-5
lines changed

examples/demo-apps/android/LlamaDemo/app/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ dependencies {
6060
implementation(files("libs/executorch.aar"))
6161
implementation("com.google.android.material:material:1.12.0")
6262
implementation("androidx.activity:activity:1.9.0")
63+
implementation("org.json:json:20250107")
6364
testImplementation("junit:junit:4.13.2")
6465
androidTestImplementation("androidx.test.ext:junit:1.1.5")
6566
androidTestImplementation("androidx.test.espresso:espresso-core:3.5.1")

examples/demo-apps/android/LlamaDemo/app/src/androidTest/java/com/example/executorchllamademo/PerfTest.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.ArrayList;
1919
import java.util.Arrays;
2020
import java.util.List;
21+
import org.json.JSONObject;
2122
import org.junit.Test;
2223
import org.junit.runner.RunWith;
2324
import org.pytorch.executorch.extension.llm.LlmCallback;
@@ -64,7 +65,12 @@ public void onResult(String result) {
6465
}
6566

6667
@Override
67-
public void onStats(float tps) {
68+
public void onStats(String stats) {
69+
JSONObject jsonObject = new JSONObject(stats);
70+
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
71+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
72+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
73+
float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
6874
tokensPerSecond.add(tps);
6975
}
7076

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import java.util.List;
5050
import java.util.concurrent.Executor;
5151
import java.util.concurrent.Executors;
52+
import org.json.JSONObject;
5253
import org.pytorch.executorch.extension.llm.LlmCallback;
5354
import org.pytorch.executorch.extension.llm.LlmModule;
5455

@@ -97,10 +98,15 @@ public void onResult(String result) {
9798
}
9899

99100
@Override
100-
public void onStats(float tps) {
101+
public void onStats(String result) {
101102
runOnUiThread(
102103
() -> {
103104
if (mResultMessage != null) {
105+
JSONObject jsonObject = new JSONObject(stats);
106+
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
107+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
108+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
109+
float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
104110
mResultMessage.setTokensPerSecond(tps);
105111
mMessageAdapter.notifyDataSetChanged();
106112
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunner.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import android.os.Looper;
1414
import android.os.Message;
1515
import androidx.annotation.NonNull;
16+
import org.json.JSONObject;
1617
import org.pytorch.executorch.extension.llm.LlmCallback;
1718
import org.pytorch.executorch.extension.llm.LlmModule;
1819

@@ -69,7 +70,12 @@ public void onResult(String result) {
6970
}
7071

7172
@Override
72-
public void onStats(float tps) {
73+
public void onStats(String stats) {
74+
JSONObject jsonObject = new JSONObject(stats);
75+
int numGeneratedTokens = jsonObject.getInt("num_generated_tokens");
76+
int inferenceEndMs = jsonObject.getInt("inference_end_ms");
77+
int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
78+
float tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
7379
mCallback.onStats("tokens/second: " + tps);
7480
}
7581
}

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelRunnerCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public interface ModelRunnerCallback {
1818

1919
void onTokenGenerated(String token);
2020

21-
void onStats(String token);
21+
void onStats(String stats);
2222

2323
void onGenerationStopped();
2424
}

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public interface LlmCallback {
3737
*/
3838
@Deprecated
3939
@DoNotStrip
40-
public void onStats(float tps);
40+
default public void onStats(float tps) {}
4141

4242
/**
4343
* Called when the statistics for the generate() is available.

0 commit comments

Comments
 (0)