diff --git a/extension/benchmark/android/benchmark/README.md b/extension/benchmark/android/benchmark/README.md
index a5cdd227746..f6731023f47 100644
--- a/extension/benchmark/android/benchmark/README.md
+++ b/extension/benchmark/android/benchmark/README.md
@@ -43,13 +43,13 @@ adb push tokenizer.bin /data/local/tmp/minibench
 
 ### Generic model
 ```
 adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
   --es model_dir /data/local/tmp/minibench
 ```
 
 ### LLM
 ```
-adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.BenchmarkActivity \
+adb shell am start -W -S -n org.pytorch.minibench/org.pytorch.minibench.LlmBenchmarkActivity \
   --es model_dir /data/local/tmp/minibench --es tokenizer_path /data/local/tmp/minibench/tokenizer.bin
 ```
 
diff --git a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2 b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
index 7d668e90c84..4f8e72d21bc 100644
--- a/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
+++ b/extension/benchmark/android/benchmark/android-llm-device-farm-test-spec.yml.j2
@@ -114,11 +114,11 @@ phases:
           adb -s $DEVICEFARM_DEVICE_UDID shell sleep 180
 
           if [ -n "$BIN_FOUND" ]; then
-            adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
+            adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
              --es "model_dir" "/data/local/tmp/minibench" \
              --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
           elif [ -n "$MODEL_FOUND" ]; then
-            adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
+            adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
              --es "model_dir" "/data/local/tmp/minibench" \
              --es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.model"
           else
diff --git a/extension/benchmark/android/benchmark/app/build.gradle.kts b/extension/benchmark/android/benchmark/app/build.gradle.kts
index 4ee7efd1f97..28dfc8ae49d 100644
--- a/extension/benchmark/android/benchmark/app/build.gradle.kts
+++ b/extension/benchmark/android/benchmark/app/build.gradle.kts
@@ -6,9 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-plugins { id("com.android.application")
-  id("org.jetbrains.kotlin.android")
-}
+plugins { id("com.android.application") }
 
 android {
   namespace = "org.pytorch.minibench"
@@ -31,11 +29,8 @@ android {
     }
   }
   compileOptions {
-    sourceCompatibility = JavaVersion.VERSION_17
-    targetCompatibility = JavaVersion.VERSION_17
-  }
-  kotlinOptions {
-    jvmTarget = "17"
+    sourceCompatibility = JavaVersion.VERSION_1_8
+    targetCompatibility = JavaVersion.VERSION_1_8
   }
 }
 
@@ -45,7 +40,6 @@ dependencies {
   implementation("com.facebook.fbjni:fbjni:0.5.1")
   implementation("com.google.code.gson:gson:2.8.6")
   implementation("org.json:json:20250107")
-  implementation("androidx.core:core-ktx:1.13.1")
  testImplementation("junit:junit:4.13.2")
  androidTestImplementation("androidx.test.ext:junit:1.2.1")
  androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1")
diff --git a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
index 723829de981..7f62c509d55 100644
--- a/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
+++ b/extension/benchmark/android/benchmark/app/src/main/AndroidManifest.xml
@@ -21,6 +21,14 @@
             <intent-filter>
                 <action android:name="android.intent.action.MAIN" />
             </intent-filter>
         </activity>
+
+        <activity
+            android:name=".LlmBenchmarkActivity"
+            android:exported="true">
+            <intent-filter>
+                <action android:name="android.intent.action.MAIN" />
+            </intent-filter>
+        </activity>
 
     </application>
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
index 5e1dd48926b..78830d5a54d 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java
@@ -10,10 +10,9 @@
 
 import android.app.Activity;
 import android.content.Intent;
+import android.os.AsyncTask;
 import android.os.Bundle;
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
+import android.os.Debug;
 import android.system.ErrnoException;
 import android.system.Os;
 import com.google.gson.Gson;
@@ -22,22 +21,12 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
+import org.pytorch.executorch.Module;
 
 public class BenchmarkActivity extends Activity {
-
-  File mModel;
-  int mNumIter;
-  int mNumWarmupIter;
-  String mTokenizerPath;
-  float mTemperature;
-  String mPrompt;
-
-  HandlerThread mHandlerThread;
-  BenchmarkHandler mHandler;
-
-  List<BenchmarkMetric> mResult;
-
   @Override
   protected void onCreate(Bundle savedInstanceState) {
     super.onCreate(savedInstanceState);
@@ -58,79 +47,95 @@ protected void onCreate(Bundle savedInstanceState) {
 
     int numIter = intent.getIntExtra("num_iter", 50);
     int numWarmupIter = intent.getIntExtra("num_warm_up_iter", 10);
-    String tokenizerPath = intent.getStringExtra("tokenizer_path");
-    float temperature = intent.getFloatExtra("temperature", 0.8f);
-    String prompt = intent.getStringExtra("prompt");
-
-    mModel = model;
-    mNumIter = numIter;
-    mNumWarmupIter = numWarmupIter;
-    mTokenizerPath = tokenizerPath;
-    mTemperature = temperature;
-    mPrompt = prompt;
-    if (mPrompt == null) {
-      mPrompt = "The ultimate answer";
-    }
-    mResult = new ArrayList<>();
-    mHandlerThread = new HandlerThread("ModelRunner");
-    mHandlerThread.start();
-    mHandler = new BenchmarkHandler(mHandlerThread.getLooper(), this);
+    long pssIdle = Debug.getPss();
 
-    mHandler.sendEmptyMessage(BenchmarkHandler.MESSAGE_RUN_BENCHMARK);
-  }
+    // TODO: Format the string with a parsable format
+    Stats stats = new Stats();
 
-  void writeResult() {
-    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
-      Gson gson = new Gson();
-      writer.write(gson.toJson(mResult));
-    } catch (IOException e) {
-      e.printStackTrace();
-    } finally {
-      finish();
-    }
-  }
-}
+    new AsyncTask<Void, Void, Void>() {
+      @Override
+      protected Void doInBackground(Void... voids) {
 
-class BenchmarkHandler extends Handler {
-  public static int MESSAGE_RUN_BENCHMARK = 1;
-  public static int MESSAGE_LLM_RUN_BENCHMARK = 2;
+        // Record the time it takes to load the model and the forward method
+        stats.loadStart = System.nanoTime();
+        Module module = Module.load(model.getPath());
+        stats.errorCode = module.loadMethod("forward");
+        stats.loadEnd = System.nanoTime();
 
-  ModelRunner mModelRunner;
-  BenchmarkActivity mBenchmarkActivity;
+        for (int i = 0; i < numWarmupIter; i++) {
+          module.forward();
+        }
 
-  LlmModelRunner mLlmModelRunner;
-  LlmBenchmark mLlmBenchmark;
+        for (int i = 0; i < numIter; i++) {
+          long start = System.nanoTime();
+          module.forward();
+          double forwardMs = (System.nanoTime() - start) * 1e-6;
+          stats.latency.add(forwardMs);
+        }
+        return null;
+      }
 
-  public BenchmarkHandler(Looper looper, BenchmarkActivity benchmarkActivity) {
-    super(looper);
-    mModelRunner = new ModelRunner();
-    mBenchmarkActivity = benchmarkActivity;
+      @Override
+      protected void onPostExecute(Void aVoid) {
+
+        final BenchmarkMetric.BenchmarkModel benchmarkModel =
+            BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
+        final List<BenchmarkMetric> results = new ArrayList<>();
+        // The list of metrics we have atm includes:
+        // Avg inference latency after N iterations
+        // Currently the result has large variance from outliers, so only use
+        // 80% samples in the middle (trimmean 0.2)
+        Collections.sort(stats.latency);
+        int resultSize = stats.latency.size();
+        List<Double> usedLatencyResults =
+            stats.latency.subList(resultSize / 10, resultSize * 9 / 10);
+
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel,
+                "avg_inference_latency(ms)",
+                stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
+                0.0f));
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel,
+                "trimmean_inference_latency(ms)",
+                usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
+                0.0f));
+        // Model load time
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel,
+                "model_load_time(ms)",
+                (stats.loadEnd - stats.loadStart) * 1e-6,
+                0.0f));
+        // Load status
+        results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
+        // RAM PSS usage
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
+
+        try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+          Gson gson = new Gson();
+          writer.write(gson.toJson(results));
+        } catch (IOException e) {
+          e.printStackTrace();
+        }
+      }
+    }.execute();
   }
+}
+
+class Stats {
+  long loadStart;
+  long loadEnd;
+  List<Double> latency = new ArrayList<>();
+  int errorCode = 0;
 
   @Override
-  public void handleMessage(android.os.Message msg) {
-    if (msg.what == MESSAGE_RUN_BENCHMARK) {
-      mModelRunner.runBenchmark(
-          mBenchmarkActivity.mModel,
-          mBenchmarkActivity.mNumWarmupIter,
-          mBenchmarkActivity.mNumIter,
-          mBenchmarkActivity.mResult);
-
-      if (mBenchmarkActivity.mTokenizerPath == null) {
-        mBenchmarkActivity.writeResult();
-      } else {
-        this.sendEmptyMessage(MESSAGE_LLM_RUN_BENCHMARK);
-      }
-    } else if (msg.what == MESSAGE_LLM_RUN_BENCHMARK) {
-      mLlmBenchmark =
-          new LlmBenchmark(
-              mBenchmarkActivity,
-              mBenchmarkActivity.mModel.getPath(),
-              mBenchmarkActivity.mTokenizerPath,
-              mBenchmarkActivity.mPrompt,
-              mBenchmarkActivity.mTemperature,
-              mBenchmarkActivity.mResult);
-    }
+  public String toString() {
+    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
   }
 }
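For reference, the trimmed-mean metric added above sorts the latency samples and averages only the middle 80% (trimmean 0.2), so a handful of outlier iterations cannot skew the headline number. A minimal, self-contained sketch of that arithmetic outside the Activity; the class name and sample values are illustrative, not part of this patch:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class TrimmeanSketch {
  public static void main(String[] args) {
    // Ten illustrative forward() latencies in ms, including two outliers (9.0 and 90.0).
    List<Double> latency =
        new ArrayList<>(Arrays.asList(12.1, 11.8, 11.9, 9.0, 12.0, 11.7, 12.3, 11.6, 12.2, 90.0));
    Collections.sort(latency);
    int resultSize = latency.size();
    // Same arithmetic as BenchmarkActivity: integer division drops the bottom and top 10%.
    List<Double> middle = latency.subList(resultSize / 10, resultSize * 9 / 10);
    double trimmean = middle.stream().mapToDouble(d -> d).average().orElse(0.0);
    System.out.println("trimmean_inference_latency(ms): " + trimmean); // prints 11.95
  }
}
```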
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java
similarity index 57%
rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java
index 0c0436d2676..f6a894d6a1f 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmark.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java
@@ -8,33 +8,57 @@
 
 package org.pytorch.minibench;
 
+import android.app.Activity;
+import android.content.Intent;
+import android.os.Bundle;
+import android.system.ErrnoException;
+import android.system.Os;
 import android.util.Log;
+import com.google.gson.Gson;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import org.json.JSONException;
 import org.json.JSONObject;
 
-public class LlmBenchmark implements LlmModelRunnerCallback {
-  LlmModelRunner mLlmModelRunner;
+public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
+  ModelRunner mModelRunner;
 
   String mPrompt;
   StatsInfo mStatsInfo;
-  List<BenchmarkMetric> mResults;
-  BenchmarkActivity mActivity;
 
-  LlmBenchmark(
-      BenchmarkActivity activity,
-      String modelFile,
-      String tokenizerPath,
-      String prompt,
-      float temperature,
-      List<BenchmarkMetric> results) {
-    mResults = results;
-    mActivity = activity;
+  @Override
+  protected void onCreate(Bundle savedInstanceState) {
+    super.onCreate(savedInstanceState);
+
+    Intent intent = getIntent();
+
+    File modelDir = new File(intent.getStringExtra("model_dir"));
+    File model =
+        Arrays.stream(modelDir.listFiles())
+            .filter(file -> file.getName().endsWith(".pte"))
+            .findFirst()
+            .get();
+    String tokenizerPath = intent.getStringExtra("tokenizer_path");
+
+    float temperature = intent.getFloatExtra("temperature", 0.8f);
+    mPrompt = intent.getStringExtra("prompt");
+    if (mPrompt == null) {
+      mPrompt = "The ultimate answer";
+    }
+
+    try {
+      Os.setenv("ADSP_LIBRARY_PATH", getApplicationInfo().nativeLibraryDir, true);
+    } catch (ErrnoException e) {
+      finish();
+    }
 
     mStatsInfo = new StatsInfo();
-    mStatsInfo.modelName = modelFile.substring(modelFile.lastIndexOf('/') + 1).replace(".pte", "");
-    mPrompt = prompt;
-    mLlmModelRunner = new LlmModelRunner(modelFile, tokenizerPath, temperature, this);
+    mStatsInfo.modelName = model.getName().replace(".pte", "");
+    mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
     mStatsInfo.loadStart = System.nanoTime();
   }
@@ -48,7 +72,7 @@ public void onModelLoaded(int status) {
       return;
     }
     mStatsInfo.generateStart = System.nanoTime();
-    mLlmModelRunner.generate(mPrompt);
+    mModelRunner.generate(mPrompt);
   }
 
   @Override
@@ -75,26 +99,33 @@ public void onGenerationStopped() {
 
     final BenchmarkMetric.BenchmarkModel benchmarkModel =
         BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.modelName);
+    final List<BenchmarkMetric> results = new ArrayList<>();
     // The list of metrics we have atm includes:
     // Load status
-    mResults.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0));
+    results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsInfo.loadStatus, 0));
     // Model load time
-    mResults.add(
+    results.add(
         new BenchmarkMetric(
             benchmarkModel,
-            "llm_model_load_time(ms)",
+            "model_load_time(ms)",
             (mStatsInfo.loadEnd - mStatsInfo.loadStart) * 1e-6,
             0.0f));
     // LLM generate time
-    mResults.add(
+    results.add(
         new BenchmarkMetric(
             benchmarkModel,
             "generate_time(ms)",
             (mStatsInfo.generateEnd - mStatsInfo.generateStart) * 1e-6,
             0.0f));
     // Token per second
-    mResults.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
-    mActivity.writeResult();
+    results.add(new BenchmarkMetric(benchmarkModel, "token_per_sec", mStatsInfo.tps, 0.0f));
+
+    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+      Gson gson = new Gson();
+      writer.write(gson.toJson(results));
+    } catch (IOException e) {
+      e.printStackTrace();
+    }
   }
 }
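With this change, both activities persist results the same way: Gson serializes a List&lt;BenchmarkMetric&gt; straight to benchmark_results.json under getFilesDir(). A rough sketch of that serialization path, using a hypothetical stand-in POJO since BenchmarkMetric's fields are not shown in this patch:

```java
import com.google.gson.Gson;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class GsonWriteSketch {
  // Stand-in for BenchmarkMetric; the field names here are illustrative, not the real schema.
  static class Metric {
    String name;
    double value;

    Metric(String name, double value) {
      this.name = name;
      this.value = value;
    }
  }

  public static void main(String[] args) {
    List<Metric> results = new ArrayList<>();
    results.add(new Metric("model_load_time(ms)", 132.4));
    results.add(new Metric("load_status", 0));
    // Same pattern as the activities: one JSON array, one object per metric.
    // On device the file lives at getFilesDir() + "/benchmark_results.json".
    try (FileWriter writer = new FileWriter("benchmark_results.json")) {
      writer.write(new Gson().toJson(results));
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}
```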
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
deleted file mode 100644
index a1b434a37bf..00000000000
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunner.java
+++ /dev/null
@@ -1,97 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-package org.pytorch.minibench;
-
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
-import android.os.Message;
-import org.pytorch.executorch.extension.llm.LlmCallback;
-import org.pytorch.executorch.extension.llm.LlmModule;
-
-/** A helper class to handle all model running logic within this class. */
-public class LlmModelRunner implements LlmCallback {
-  LlmModule mModule = null;
-
-  String mModelFilePath = "";
-  String mTokenizerFilePath = "";
-
-  LlmModelRunnerCallback mCallback = null;
-
-  HandlerThread mHandlerThread = null;
-  Handler mHandler = null;
-
-  /**
-   * ] Helper class to separate between UI logic and model runner logic. Automatically handle
-   * generate() request on worker thread.
-   *
-   * @param modelFilePath
-   * @param tokenizerFilePath
-   * @param callback
-   */
-  LlmModelRunner(
-      String modelFilePath,
-      String tokenizerFilePath,
-      float temperature,
-      LlmModelRunnerCallback callback) {
-    mModelFilePath = modelFilePath;
-    mTokenizerFilePath = tokenizerFilePath;
-    mCallback = callback;
-
-    mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, 0.8f);
-    mHandlerThread = new HandlerThread("LlmModelRunner");
-    mHandlerThread.start();
-    mHandler = new LlmModelRunnerHandler(mHandlerThread.getLooper(), this);
-
-    mHandler.sendEmptyMessage(LlmModelRunnerHandler.MESSAGE_LOAD_MODEL);
-  }
-
-  int generate(String prompt) {
-    Message msg = Message.obtain(mHandler, LlmModelRunnerHandler.MESSAGE_GENERATE, prompt);
-    msg.sendToTarget();
-    return 0;
-  }
-
-  void stop() {
-    mModule.stop();
-  }
-
-  @Override
-  public void onResult(String result) {
-    mCallback.onTokenGenerated(result);
-  }
-
-  @Override
-  public void onStats(String result) {
-    mCallback.onStats(result);
-  }
-}
-
-class LlmModelRunnerHandler extends Handler {
-  public static int MESSAGE_LOAD_MODEL = 1;
-  public static int MESSAGE_GENERATE = 2;
-
-  private final LlmModelRunner mLlmModelRunner;
-
-  public LlmModelRunnerHandler(Looper looper, LlmModelRunner llmModelRunner) {
-    super(looper);
-    mLlmModelRunner = llmModelRunner;
-  }
-
-  @Override
-  public void handleMessage(android.os.Message msg) {
-    if (msg.what == MESSAGE_LOAD_MODEL) {
-      int status = mLlmModelRunner.mModule.load();
-      mLlmModelRunner.mCallback.onModelLoaded(status);
-    } else if (msg.what == MESSAGE_GENERATE) {
-      mLlmModelRunner.mModule.generate((String) msg.obj, mLlmModelRunner);
-      mLlmModelRunner.mCallback.onGenerationStopped();
-    }
-  }
-}
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
index 3913a8d76f5..0a75b47f3a6 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java
@@ -8,70 +8,90 @@
 
 package org.pytorch.minibench;
 
-import android.os.Debug;
-import java.io.File;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import org.pytorch.executorch.Module;
-
-public class ModelRunner {
+import android.os.Handler;
+import android.os.HandlerThread;
+import android.os.Looper;
+import android.os.Message;
+import org.pytorch.executorch.extension.llm.LlmCallback;
+import org.pytorch.executorch.extension.llm.LlmModule;
+
+/** A helper class that encapsulates all model-running logic. */
+public class ModelRunner implements LlmCallback {
+  LlmModule mModule = null;
+
+  String mModelFilePath = "";
+  String mTokenizerFilePath = "";
+
+  ModelRunnerCallback mCallback = null;
+
+  HandlerThread mHandlerThread = null;
+  Handler mHandler = null;
 
   /**
-   * @return list of #BenchmarkMetric
+   * Separates UI logic from model-runner logic; generate() requests are handled automatically on a
+   * worker thread.
+   *
+   * @param modelFilePath path to the .pte model file
+   * @param tokenizerFilePath path to the tokenizer file
+   * @param temperature sampling temperature for generation
+   * @param callback receiver of load and generation events
+   */
-  public void runBenchmark(
-      File model, int numWarmupIter, int numIter, List<BenchmarkMetric> results) {
-    long pssIdle = Debug.getPss();
+  ModelRunner(
+      String modelFilePath,
+      String tokenizerFilePath,
+      float temperature,
+      ModelRunnerCallback callback) {
+    mModelFilePath = modelFilePath;
+    mTokenizerFilePath = tokenizerFilePath;
+    mCallback = callback;
 
-    List<Double> latency = new ArrayList<>();
+    mModule = new LlmModule(mModelFilePath, mTokenizerFilePath, temperature);
+    mHandlerThread = new HandlerThread("ModelRunner");
+    mHandlerThread.start();
+    mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this);
 
-    long loadStart = System.nanoTime();
-    Module module = Module.load(model.getPath());
-    int errorCode = module.loadMethod("forward");
-    long loadEnd = System.nanoTime();
+    mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL);
+  }
 
-    for (int i = 0; i < numWarmupIter; i++) {
-      module.forward();
-    }
+  int generate(String prompt) {
+    Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt);
+    msg.sendToTarget();
+    return 0;
+  }
 
-    for (int i = 0; i < numIter; i++) {
-      long start = System.nanoTime();
-      module.forward();
-      double forwardMs = (System.nanoTime() - start) * 1e-6;
-      latency.add(forwardMs);
-    }
+  void stop() {
+    mModule.stop();
+  }
 
-    final BenchmarkMetric.BenchmarkModel benchmarkModel =
-        BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
-    // The list of metrics we have atm includes:
-    // Avg inference latency after N iterations
-    // Currently the result has large variance from outliers, so only use
-    // 80% samples in the middle (trimmean 0.2)
-    Collections.sort(latency);
-    int resultSize = latency.size();
-    List<Double> usedLatencyResults = latency.subList(resultSize / 10, resultSize * 9 / 10);
-
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel,
-            "avg_inference_latency(ms)",
-            latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
-            0.0f));
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel,
-            "trimmean_inference_latency(ms)",
-            usedLatencyResults.stream().mapToDouble(l -> l).average().orElse(0.0f),
-            0.0f));
-    // Model load time
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel, "model_load_time(ms)", (loadEnd - loadStart) * 1e-6, 0.0f));
-    // Load status
-    results.add(new BenchmarkMetric(benchmarkModel, "load_status", errorCode, 0));
-    // RAM PSS usage
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel, "ram_pss_usage(mb)", (Debug.getPss() - pssIdle) / 1024, 0));
+  @Override
+  public void onResult(String result) {
+    mCallback.onTokenGenerated(result);
+  }
+
+  @Override
+  public void onStats(String result) {
+    mCallback.onStats(result);
+  }
+}
+
+class ModelRunnerHandler extends Handler {
+  public static int MESSAGE_LOAD_MODEL = 1;
+  public static int MESSAGE_GENERATE = 2;
+
+  private final ModelRunner mModelRunner;
+
+  public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) {
+    super(looper);
+    mModelRunner = modelRunner;
+  }
+
+  @Override
+  public void handleMessage(android.os.Message msg) {
+    if (msg.what == MESSAGE_LOAD_MODEL) {
+      int status = mModelRunner.mModule.load();
+      mModelRunner.mCallback.onModelLoaded(status);
+    } else if (msg.what == MESSAGE_GENERATE) {
+      mModelRunner.mModule.generate((String) msg.obj, mModelRunner);
+      mModelRunner.mCallback.onGenerationStopped();
+    }
   }
 }
diff --git a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java
similarity index 62%
rename from extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt
rename to extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java
index cd2fecdf81c..8503d47ccce 100644
--- a/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmModelRunnerCallback.kt
+++ b/extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java
@@ -6,21 +6,19 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-
-package org.pytorch.minibench
-
+package org.pytorch.minibench;
 
 /**
  * A helper interface within the app for MainActivity and Benchmarking to handle callback from
  * ModelRunner.
  */
-interface LlmModelRunnerCallback {
+public interface ModelRunnerCallback {
 
-  fun onModelLoaded(status: Int)
+  void onModelLoaded(int status);
 
-  fun onTokenGenerated(token: String)
+  void onTokenGenerated(String token);
 
-  fun onStats(result: String)
+  void onStats(String result);
 
-  fun onGenerationStopped()
+  void onGenerationStopped();
 }
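Taken together, ModelRunner posts load and generate work to its HandlerThread and reports back through ModelRunnerCallback, which is exactly how LlmBenchmarkActivity consumes it. A minimal sketch of a caller wiring the two together; the prompt, paths, and logging are illustrative, and the class must live in org.pytorch.minibench because the ModelRunner constructor and generate() are package-private:

```java
package org.pytorch.minibench;

import android.util.Log;

// Illustrative consumer of the ModelRunner/ModelRunnerCallback pair from this patch.
public class CallbackSketch implements ModelRunnerCallback {
  private ModelRunner mRunner;

  void start() {
    // Construction starts the HandlerThread and posts MESSAGE_LOAD_MODEL immediately.
    mRunner =
        new ModelRunner(
            "/data/local/tmp/minibench/model.pte",
            "/data/local/tmp/minibench/tokenizer.bin",
            0.8f,
            this);
  }

  @Override
  public void onModelLoaded(int status) {
    // Non-zero status is treated as a load failure, mirroring LlmBenchmarkActivity.
    if (status == 0) {
      mRunner.generate("The ultimate answer"); // queued as MESSAGE_GENERATE on the worker thread
    }
  }

  @Override
  public void onTokenGenerated(String token) {
    Log.i("CallbackSketch", token);
  }

  @Override
  public void onStats(String result) {
    Log.i("CallbackSketch", result);
  }

  @Override
  public void onGenerationStopped() {
    Log.i("CallbackSketch", "generation finished");
  }
}
```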
diff --git a/extension/benchmark/android/benchmark/build.gradle.kts b/extension/benchmark/android/benchmark/build.gradle.kts
index b1ed5127dfb..ac625be8e02 100644
--- a/extension/benchmark/android/benchmark/build.gradle.kts
+++ b/extension/benchmark/android/benchmark/build.gradle.kts
@@ -7,6 +7,4 @@
  */
 
 // Top-level build file where you can add configuration options common to all sub-projects/modules.
-plugins { id("com.android.application") version "8.1.0" apply false
-  id("org.jetbrains.kotlin.android") version "2.1.10" apply false
-}
+plugins { id("com.android.application") version "8.1.0" apply false }