Skip to content

Commit 440048c

Browse files
authored
Add an activity for benchmarking only
Differential Revision: D60399589 Pull Request resolved: #4443
1 parent 0c26dc0 commit 440048c

File tree

5 files changed

+258
-0
lines changed

5 files changed

+258
-0
lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/AndroidManifest.xml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@
4747
<category android:name="android.intent.category.LAUNCHER" />
4848
</intent-filter>
4949
</activity>
50+
51+
<activity
52+
android:name=".LlmBenchmarkRunner"
53+
android:exported="true">
54+
<intent-filter>
55+
<action android:name="com.example.executorchllamademo.BENCHMARK" />
56+
</intent-filter>
57+
</activity>
58+
5059
</application>
5160

5261
</manifest>
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
import android.app.Activity;
12+
import android.content.Intent;
13+
import android.os.Bundle;
14+
import android.util.Log;
15+
import android.widget.TextView;
16+
import androidx.annotation.NonNull;
17+
import java.io.FileWriter;
18+
import java.io.IOException;
19+
20+
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
21+
ModelRunner mModelRunner;
22+
23+
String mPrompt;
24+
TextView mTextView;
25+
StatsDump mStatsDump;
26+
27+
@Override
28+
protected void onCreate(Bundle savedInstanceState) {
29+
super.onCreate(savedInstanceState);
30+
setContentView(R.layout.activity_benchmarking);
31+
mTextView = findViewById(R.id.log_view);
32+
33+
Intent intent = getIntent();
34+
35+
String modelPath = intent.getStringExtra("model_path");
36+
String tokenizerPath = intent.getStringExtra("tokenizer_path");
37+
38+
float temperature = intent.getFloatExtra("temperature", 0.8f);
39+
mPrompt = intent.getStringExtra("prompt");
40+
if (mPrompt == null) {
41+
mPrompt = "The ultimate answer";
42+
}
43+
44+
mStatsDump = new StatsDump();
45+
mModelRunner = new ModelRunner(modelPath, tokenizerPath, temperature, this);
46+
mStatsDump.loadStart = System.currentTimeMillis();
47+
}
48+
49+
@Override
50+
public void onModelLoaded(int status) {
51+
mStatsDump.loadEnd = System.currentTimeMillis();
52+
if (status != 0) {
53+
Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
54+
onGenerationStopped();
55+
return;
56+
}
57+
mStatsDump.generateStart = System.currentTimeMillis();
58+
mModelRunner.generate(mPrompt);
59+
}
60+
61+
@Override
62+
public void onTokenGenerated(String token) {
63+
runOnUiThread(
64+
() -> {
65+
mTextView.append(token);
66+
});
67+
}
68+
69+
@Override
70+
public void onStats(String stats) {
71+
mStatsDump.tokens = stats;
72+
}
73+
74+
@Override
75+
public void onGenerationStopped() {
76+
mStatsDump.generateEnd = System.currentTimeMillis();
77+
runOnUiThread(
78+
() -> {
79+
mTextView.append(mStatsDump.toString());
80+
});
81+
82+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
83+
writer.write(mStatsDump.toString());
84+
} catch (IOException e) {
85+
e.printStackTrace();
86+
}
87+
}
88+
}
89+
90+
class StatsDump {
91+
long loadStart;
92+
long loadEnd;
93+
long generateStart;
94+
long generateEnd;
95+
String tokens;
96+
97+
@NonNull
98+
@Override
99+
public String toString() {
100+
return "loadStart: "
101+
+ loadStart
102+
+ "\nloadEnd: "
103+
+ loadEnd
104+
+ "\ngenerateStart: "
105+
+ generateStart
106+
+ "\ngenerateEnd: "
107+
+ generateEnd
108+
+ "\n"
109+
+ tokens;
110+
}
111+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
import android.os.Handler;
12+
import android.os.HandlerThread;
13+
import android.os.Looper;
14+
import android.os.Message;
15+
import androidx.annotation.NonNull;
16+
import org.pytorch.executorch.LlamaCallback;
17+
import org.pytorch.executorch.LlamaModule;
18+
19+
/** A helper class to handle all model running logic within this class. */
20+
public class ModelRunner implements LlamaCallback {
21+
LlamaModule mModule = null;
22+
23+
String mModelFilePath = "";
24+
String mTokenizerFilePath = "";
25+
26+
ModelRunnerCallback mCallback = null;
27+
28+
HandlerThread mHandlerThread = null;
29+
Handler mHandler = null;
30+
31+
/**
32+
* ] Helper class to separate between UI logic and model runner logic. Automatically handle
33+
* generate() request on worker thread.
34+
*
35+
* @param modelFilePath
36+
* @param tokenizerFilePath
37+
* @param callback
38+
*/
39+
ModelRunner(
40+
String modelFilePath,
41+
String tokenizerFilePath,
42+
float temperature,
43+
ModelRunnerCallback callback) {
44+
mModelFilePath = modelFilePath;
45+
mTokenizerFilePath = tokenizerFilePath;
46+
mCallback = callback;
47+
48+
mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f);
49+
mHandlerThread = new HandlerThread("ModelRunner");
50+
mHandlerThread.start();
51+
mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
52+
53+
mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
54+
}
55+
56+
int generate(String prompt) {
57+
Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
58+
msg.sendToTarget();
59+
return 0;
60+
}
61+
62+
void stop() {
63+
mModule.stop();
64+
}
65+
66+
@Override
67+
public void onResult(String result) {
68+
mCallback.onTokenGenerated(result);
69+
}
70+
71+
@Override
72+
public void onStats(float tps) {
73+
mCallback.onStats("tokens/second: " + tps);
74+
}
75+
}
76+
77+
class ModelRunnerHandler extends Handler {
78+
public static int MESSAGE_LOAD_MODEL = 1;
79+
public static int MESSAGE_GENERATE = 2;
80+
81+
private final ModelRunner mModelRunner;
82+
83+
public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
84+
super(looper);
85+
mModelRunner = modelRunner;
86+
}
87+
88+
@Override
89+
public void handleMessage(@NonNull android.os.Message msg) {
90+
if (msg.what == MESSAGE_LOAD_MODEL) {
91+
int status = mModelRunner.mModule.load();
92+
mModelRunner.mCallback.onModelLoaded(status);
93+
} else if (msg.what == MESSAGE_GENERATE) {
94+
mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
95+
mModelRunner.mCallback.onGenerationStopped();
96+
}
97+
}
98+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo;
10+
11+
/**
12+
* A helper interface within the app for MainActivity and Benchmarking to handle callback from
13+
* ModelRunner.
14+
*/
15+
public interface ModelRunnerCallback {
16+
17+
void onModelLoaded(int status);
18+
19+
void onTokenGenerated(String token);
20+
21+
void onStats(String token);
22+
23+
void onGenerationStopped();
24+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<?xml version="1.0" encoding="utf-8"?>
2+
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
3+
xmlns:tools="http://schemas.android.com/tools"
4+
android:layout_width="match_parent"
5+
android:layout_height="match_parent"
6+
android:orientation="vertical"
7+
android:clipToPadding="false"
8+
android:focusableInTouchMode="true"
9+
tools:context=".LlmBenchmarkRunner">
10+
11+
<TextView
12+
android:layout_width="match_parent"
13+
android:layout_height="match_parent"
14+
android:id="@+id/log_view" />
15+
16+
</LinearLayout>

0 commit comments

Comments
 (0)