4 changes: 2 additions & 2 deletions extension/benchmark/android/benchmark/README.md
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench

### Generic model
```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
  --es model_dir /data/local/tmp/minibench
```

### LLM
```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
  --es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
```
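Note that both invocations above now target the same BenchmarkActivity; the only switch between the generic and LLM flows is the optional tokenizer_path extra. When the extra is absent only the generic benchmark runs, and when it is present the LLM benchmark is chained after it (see the BenchmarkActivity.java diff below). Here is a minimal, self-contained Java sketch of that dispatch rule; the DispatchSketch class and plan() method are hypothetical names, not part of the app:

```
// Illustrative sketch only: models the routing rule this PR introduces.
// The real decision lives in BenchmarkHandler.handleMessage (see the
// BenchmarkActivity.java diff below); DispatchSketch and plan() are
// hypothetical names.
public class DispatchSketch {
  static String plan(String tokenizerPath) {
    // Mirrors: if (mTokenizerPath == null) writeResult(); else chain the LLM run.
    if (tokenizerPath == null) {
      return "generic benchmark only";
    }
    return "generic benchmark, then LLM benchmark using " + tokenizerPath;
  }

  public static void main(String[] args) {
    System.out.println(plan(null));
    System.out.println(plan("/data/local/tmp/minibench/tokenizer.bin"));
  }
}
```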

@@ -114,11 +114,11 @@ phases:
      adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180

      if [ -n "$BIN_FOUND" ]; then
-        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
          --es "model_dir" "/data/local/tmp/minibench" \
          --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
      elif [ -n "$MODEL_FOUND" ]; then
-        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
+        adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
          --es "model_dir" "/data/local/tmp/minibench" \
          --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
      else
12 changes: 9 additions & 3 deletions extension/benchmark/android/benchmark/app/build.gradle.kts
@@ -6,7 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

-plugins { id("com.android.application") }
+plugins { id("com.android.application")
+  id("org.jetbrains.kotlin.android")
+}

android {
namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
    }
  }
  compileOptions {
-    sourceCompatibility = JavaVersion.VERSION_1_8
-    targetCompatibility = JavaVersion.VERSION_1_8
+    sourceCompatibility = JavaVersion.VERSION_17
+    targetCompatibility = JavaVersion.VERSION_17
  }
+  kotlinOptions {
+    jvmTarget = "17"
+  }
}

@@ -40,6 +45,7 @@ dependencies {
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
implementation("org.json:json:20250107")
implementation("androidx.core:core-ktx:1.13.1")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
@@ -21,14 +21,6 @@
            </intent-filter>
        </activity>

-        <activity
-            android:name=".LlmBenchmarkActivity"
-            android:exported="true">
-            <intent-filter>
-                <action android:name="org.pytorch.minibench.BENCHMARK" />
-            </intent-filter>
-        </activity>

    </application>

</manifest>
extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
@@ -10,132 +10,118 @@

import android.app.Activity;
import android.content.Intent;
-import android.os.AsyncTask;
import android.os.Bundle;
-import android.os.Debug;
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
import android.system.ErrnoException;
import android.system.Os;

import com.google.gson.Gson;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.Collections;
import java.util.List;
-import java.util.stream.Collectors;
-import org.pytorch.executorch.Module;

public class BenchmarkActivity extends Activity {
-  @Override
-  protected void onCreate(Bundle savedInstanceState) {
-    super.onCreate(savedInstanceState);
-
-    try {
-      Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
-    } catch (ErrnoException e) {
-      finish();
-    }
-
-    Intent intent = getIntent();
-    File modelDir = new File(intent.getStringExtra("model_dir"));
-    File model =
-        Arrays.stream(modelDir.listFiles())
-            .filter(file -> file.getName().endsWith(".pte"))
-            .findFirst()
-            .get();
-
-    int numIter = intent.getIntExtra("num_iter", 50);
-    int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
+  File mModel;
+  int mNumIter;
+  int mNumWarmupIter;
+  String mTokenizerPath;
+  float mTemperature;
+  String mPrompt;

-    long pssIdle = Debug.getPss();
+  HandlerThread mHandlerThread;
+  BenchmarkHandler mHandler;

-    // TODO: Format the string with a parsable format
-    Stats stats = new Stats();
+  List<BenchmarkMetric> mResult;

-    new AsyncTask<Void, Void, Void>() {
-      @Override
-      protected Void doInBackground(Void... voids) {
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);

-        // Record the time it takes to load the model and the forward method
-        stats.loadStart = System.nanoTime();
-        Module module = Module.load(model.getPath());
-        stats.errorCode = module.loadMethod("forward");
-        stats.loadEnd = System.nanoTime();
-
-        for (int i = 0; i < numWarmupIter; i++) {
-          module.forward();
+    try {
+      Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
+    } catch (ErrnoException e) {
+      finish();
    }

-        for (int i = 0; i < numIter; i++) {
-          long start = System.nanoTime();
-          module.forward();
-          double forwardMs = (System.nanoTime() - start) * 1e-6;
-          stats.latency.add(forwardMs);
+    Intent intent = getIntent();
+    File modelDir = new File(intent.getStringExtra("model_dir"));
+    File model =
+        Arrays.stream(modelDir.listFiles())
+            .filter(file -> file.getName().endsWith(".pte"))
+            .findFirst()
+            .get();
+
+    int numIter = intent.getIntExtra("num_iter", 50);
+    int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
+    String tokenizerPath = intent.getStringExtra("tokenizer_path");
+    float temperature = intent.getFloatExtra("temperature", 0.8f);
+    String prompt = intent.getStringExtra("prompt");
+
+    mModel = model;
+    mNumIter = numIter;
+    mNumWarmupIter = numWarmupIter;
+    mTokenizerPath = tokenizerPath;
+    mTemperature = temperature;
+    mPrompt = prompt;
+    if (mPrompt == null) {
+      mPrompt = "The ultimate answer";
    }
-        return null;
-      }

-      @Override
-      protected void onPostExecute(Void aVoid) {

-        final BenchmarkMetric.BenchmarkModel benchmarkModel =
-            BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
-        final List<BenchmarkMetric> results = new ArrayList<>();
-        // The list of metrics we have atm includes:
-        // Avg inference latency after N iterations
-        // Currently the result has large variance from outliers, so only use
-        // 80% samples in the middle (trimmean 0.2)
-        Collections.sort(stats.latency);
-        int resultSize = stats.latency.size();
-        List<Double> usedLatencyResults =
-            stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
-
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "avg_inference_latency(ms)",
-                stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
-                0.0f));
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "trimmean_inference_latency(ms)",
-                usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
-                0.0f));
-        // Model load time
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel,
-                "model_load_time(ms)",
-                (stats.loadEnd - stats.loadStart) * 1e-6,
-                0.0f));
-        // Load status
-        results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
-        // RAM PSS usage
-        results.add(
-            new BenchmarkMetric(
-                benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
+    mResult = new ArrayList<>();

+    mHandlerThread = new HandlerThread("ModelRunner");
+    mHandlerThread.start();
+    mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);

+    mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
  }

+  void writeResult() {
    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
-          Gson gson = new Gson();
-          writer.write(gson.toJson(results));
+      Gson gson = new Gson();
+      writer.write(gson.toJson(mResult));
    } catch (IOException e) {
-          e.printStackTrace();
+      e.printStackTrace();
    } finally {
      finish();
    }
  }
-    }.execute();
-  }
-}
+}

-class Stats {
-  long loadStart;
-  long loadEnd;
-  List<Double> latency = new ArrayList<>();
-  int errorCode = 0;
+class BenchmarkHandler extends Handler {
+  public static int MESSAGE_RUN_BENCHMARK = 1;
+  public static int MESSAGE_LLM_RUN_BENCHMARK = 2;

+  ModelRunner mModelRunner;
+  BenchmarkActivity mBenchmarkActivity;

-  @Override
-  public String toString() {
-    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
-  }
+  LlmModelRunner mLlmModelRunner;
+  LlmBenchmark mLlmBenchmark;

+  public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
+    super(looper);
+    mModelRunner = new ModelRunner();
+    mBenchmarkActivity = benchmarkActivity;
+  }

+  @Override
+  public void handleMessage(android.os.Message msg) {
+    if (msg.what == MESSAGE_RUN_BENCHMARK) {
+      mModelRunner.runBenchmark(mBenchmarkActivity.mModel, mBenchmarkActivity.mNumWarmupIter, mBenchmarkActivity.mNumIter, mBenchmarkActivity.mResult);
+
+      if (mBenchmarkActivity.mTokenizerPath == null) {
+        mBenchmarkActivity.writeResult();
+      } else {
+        this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
+      }
+    } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
+      mLlmBenchmark = new LlmBenchmark(mBenchmarkActivity, mBenchmarkActivity.mModel.getPath(), mBenchmarkActivity.mTokenizerPath, mBenchmarkActivity.mPrompt, mBenchmarkActivity.mTemperature, mBenchmarkActivity.mResult);
+    }
+  }
}
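For reference, the metric math this refactor moves out of the activity is easy to restate: the removed onPostExecute computed a trimmed mean ("trimmean 0.2") by sorting the latency samples and averaging only the middle 80%, which damps outlier variance; presumably that logic now lives in ModelRunner, which this diff does not show. A standalone Java sketch of the computation, with hypothetical class and variable names:

```
// Standalone sketch of the trimmed-mean ("trimmean 0.2") metric from the
// removed onPostExecute code above; TrimmedMean and `samples` are
// hypothetical names, not part of the app.
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class TrimmedMean {
  // Sort, drop the lowest 10% and highest 10% of samples to damp outliers,
  // then average the remaining middle 80%, exactly what the old code did with
  // latency.subList(resultSize / 10, resultSize * 9 / 10).
  static double trimmedMeanMs(List<Double> latencyMs) {
    List<Double> sorted = new ArrayList<>(latencyMs);
    Collections.sort(sorted);
    int n = sorted.size();
    return sorted.subList(n / 10, n * 9 / 10).stream()
        .mapToDouble(d -> d)
        .average()
        .orElse(0.0);
  }

  public static void main(String[] args) {
    // One outlier (55.0 ms) among otherwise ~12 ms runs gets trimmed away.
    List<Double> samples =
        List.of(12.1, 11.9, 12.3, 55.0, 12.0, 11.8, 12.2, 12.4, 12.0, 12.1);
    System.out.printf("trimmed mean = %.2f ms%n", trimmedMeanMs(samples));
  }
}
```

With the ten samples above, subList(1, 9) drops exactly the smallest and largest values, so the single 55 ms outlier does not skew the reported average.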