DType.java
@@ -73,4 +73,13 @@ public enum DType {
DType(int jniCode) {
this.jniCode = jniCode;
}

public static DType fromJniCode(int jniCode) {
for (DType dtype : values()) {
if (dtype.jniCode == jniCode) {
return dtype;
}
}
throw new IllegalArgumentException("No DType found for jniCode " + jniCode);
}
}
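
A minimal sketch of the new reverse lookup (the demo class is illustrative, and which DType a given code maps to depends on the enum's jniCode assignments):

```
import org.pytorch.executorch.DType;

public class DTypeLookupDemo {
  public static void main(String[] args) {
    try {
      // Reverse lookup from a JNI-side scalar-type code; 6 is an arbitrary probe.
      DType dt = DType.fromJniCode(6);
      System.out.println("jniCode 6 -> " + dt);
    } catch (IllegalArgumentException e) {
      // Codes with no matching constant still fail fast at the DType level.
      System.out.println(e.getMessage());
    }
  }
}
```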
Tensor.java
@@ -8,6 +8,7 @@

package org.pytorch.executorch;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
@@ -630,6 +631,31 @@ public String toString() {
}
}

static class Tensor_unsupported extends Tensor {
private final ByteBuffer data;
private final DType myDtype;

private Tensor_unsupported(ByteBuffer data, long[] shape, DType dtype) {
super(shape);
this.data = data;
this.myDtype = dtype;
Log.e(
"ExecuTorch",
toString() + " in Java. Please consider re-exporting the model with a proper return type");
}

@Override
public DType dtype() {
return myDtype;
}

@Override
public String toString() {
return String.format(
"Unsupported tensor(%s, dtype=%s)", Arrays.toString(shape), this.myDtype);
}
}

// region checks
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
if (!expression) {
@@ -675,7 +701,7 @@ private static Tensor nativeNewTensor(
} else if (DType.INT8.jniCode == dtype) {
tensor = new Tensor_int8(data, shape);
} else {
throw new IllegalArgumentException("Unknown Tensor dtype");
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
}
tensor.mHybridData = hybridData;
return tensor;
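
With this fallback, an output whose dtype has no Java wrapper no longer throws inside nativeNewTensor; it surfaces as a Tensor_unsupported that logs an error. A caller-side sketch, assuming the usual Module/EValue flow (the model path is hypothetical):

```
import org.pytorch.executorch.EValue;
import org.pytorch.executorch.Module;
import org.pytorch.executorch.Tensor;

public class UnsupportedDtypeDemo {
  public static void main(String[] args) {
    // Hypothetical model whose output dtype lacks a Java Tensor subclass.
    Module module = Module.load("/data/local/tmp/model.pte");
    EValue[] outputs = module.forward();
    Tensor t = outputs[0].toTensor();
    // Tensor_unsupported is package-private, so callers observe it only
    // through dtype() and the "Unsupported tensor" error log.
    System.out.println("output dtype: " + t.dtype());
  }
}
```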
4 changes: 2 additions & 2 deletions extension/benchmark/android/benchmark/README.md
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench

### Generic model
```
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
--es model_dir /data/local/tmp/minibench
```

### LLM
```
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
--es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
```
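
The metrics land in the app's private files directory as benchmark_results.json (written by BenchmarkActivity). On a debuggable build they can be read back with run-as, for example:

```
adb shell run-as org.pytorch.minibench cat files/benchmark_results.json
```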

@@ -114,11 +114,11 @@ phases:
adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180

if [ -n "$BIN_FOUND" ]; then
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
elif [ -n "$MODEL_FOUND" ]; then
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
--es "model_dir" "/data/local/tmp/minibench" \
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
else
12 changes: 9 additions & 3 deletions extension/benchmark/android/benchmark/app/build.gradle.kts
@@ -6,7 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/

plugins { id("com.android.application") }
plugins { id("com.android.application")
id("org.jetbrains.kotlin.android")
}

android {
namespace = "org.pytorch.minibench"
@@ -29,8 +31,11 @@ android {
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_1_8
targetCompatibility = JavaVersion.VERSION_1_8
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlinOptions {
jvmTarget = "17"
}
}

@@ -40,6 +45,7 @@ dependencies {
implementation("com.facebook.fbjni:fbjni:0.5.1")
implementation("com.google.code.gson:gson:2.8.6")
implementation("org.json:json:20250107")
implementation("androidx.core:core-ktx:1.13.1")
testImplementation("junit:junit:4.13.2")
androidTestImplementation("androidx.test.ext:junit:1.2.1")
androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
AndroidManifest.xml
@@ -21,14 +21,6 @@
</intent-filter>
</activity>

<activity
android:name=".LlmBenchmarkActivity"
android:exported="true">
<intent-filter>
<action android:name="org.pytorch.minibench.BENCHMARK" />
</intent-filter>
</activity>

</application>

</manifest>
BenchmarkActivity.java
@@ -10,9 +10,10 @@

import android.app.Activity;
import android.content.Intent;
import android.os.AsyncTask;
import android.os.Bundle;
import android.os.Debug;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.Looper;
import android.system.ErrnoException;
import android.system.Os;
import com.google.gson.Gson;
@@ -21,12 +22,22 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.pytorch.executorch.Module;

public class BenchmarkActivity extends Activity {

File mModel;
int mNumIter;
int mNumWarmupIter;
String mTokenizerPath;
float mTemperature;
String mPrompt;

HandlerThread mHandlerThread;
BenchmarkHandler mHandler;

List<BenchmarkMetric> mResult;

@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {

int numIter = intent.getIntExtra("num_iter", 50);
int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
String tokenizerPath = intent.getStringExtra("tokenizer_path");
float temperature = intent.getFloatExtra("temperature", 0.8f);
String prompt = intent.getStringExtra("prompt");

mModel = model;
mNumIter = numIter;
mNumWarmupIter = numWarmupIter;
mTokenizerPath = tokenizerPath;
mTemperature = temperature;
mPrompt = prompt;
if (mPrompt == null) {
mPrompt = "The ultimate answer";
}
mResult = new ArrayList<>();

long pssIdle = Debug.getPss();
mHandlerThread = new HandlerThread("ModelRunner");
mHandlerThread.start();
mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);

// TODO: Format the string with a parsable format
Stats stats = new Stats();
mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
}

new AsyncTask<Void, Void, Void>() {
@Override
protected Void doInBackground(Void... voids) {
void writeResult() {
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mResult));
} catch (IOException e) {
e.printStackTrace();
} finally {
finish();
}
}
}

// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();
class BenchmarkHandler extends Handler {
public static final int MESSAGE_RUN_BENCHMARK = 1;
public static final int MESSAGE_LLM_RUN_BENCHMARK = 2;

for (int i = 0; i < numWarmupIter; i++) {
module.forward();
}
ModelRunner mModelRunner;
BenchmarkActivity mBenchmarkActivity;

for (int i = 0; i < numIter; i++) {
long start = System.nanoTime();
module.forward();
double forwardMs = (System.nanoTime() - start) * 1e-6;
stats.latency.add(forwardMs);
}
return null;
}
LlmModelRunner mLlmModelRunner;
LlmBenchmark mLlmBenchmark;

@Override
protected void onPostExecute(Void aVoid) {

final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
// Currently the result has large variance from outliers, so only use
// 80% samples in the middle (trimmean 0.2)
Collections.sort(stats.latency);
int resultSize = stats.latency.size();
List<Double> usedLatencyResults =
stats.latency.subList(resultSize / 10, resultSize * 9 / 10);

results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
results.add(
new BenchmarkMetric(
benchmarkModel,
"trimmean_inference_latency(ms)",
usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
(stats.loadEnd - stats.loadStart) * 1e-6,
0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
// RAM PSS usage
results.add(
new BenchmarkMetric(
benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}.execute();
public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
super(looper);
mModelRunner = new ModelRunner();
mBenchmarkActivity = benchmarkActivity;
}
}

class Stats {
long loadStart;
long loadEnd;
List<Double> latency = new ArrayList<>();
int errorCode = 0;

@Override
public String toString() {
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
public void handleMessage(android.os.Message msg) {
if (msg.what == MESSAGE_RUN_BENCHMARK) {
mModelRunner.runBenchmark(
mBenchmarkActivity.mModel,
mBenchmarkActivity.mNumWarmupIter,
mBenchmarkActivity.mNumIter,
mBenchmarkActivity.mResult);

if (mBenchmarkActivity.mTokenizerPath == null) {
mBenchmarkActivity.writeResult();
} else {
this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
}
} else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
mLlmBenchmark =
new LlmBenchmark(
mBenchmarkActivity,
mBenchmarkActivity.mModel.getPath(),
mBenchmarkActivity.mTokenizerPath,
mBenchmarkActivity.mPrompt,
mBenchmarkActivity.mTemperature,
mBenchmarkActivity.mResult);
}
}
}
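
Since BenchmarkActivity now reads the LLM extras itself, one launch command drives either path; the LLM pass runs only when tokenizer_path is present. A launch that exercises the new extras (values are illustrative):

```
adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
  --es model_dir /data/local/tmp/minibench \
  --ei num_iter 100 --ei num_warm_up_iter 5 \
  --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin \
  --ef temperature 0.8 --es prompt "The ultimate answer"
```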