-
Notifications
You must be signed in to change notification settings - Fork 687
Define generic Android benchmark metric structure #5332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
27c895d
57c3c5e
5c5499a
0488a12
60800bd
a529c3b
b2b117e
a66b834
2ea88c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
|
||
import android.app.Activity; | ||
import android.content.Intent; | ||
import android.os.Build; | ||
import android.os.Bundle; | ||
import android.util.Log; | ||
import android.widget.TextView; | ||
|
@@ -18,7 +19,11 @@ | |
import java.io.File; | ||
import java.io.FileWriter; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { | ||
ModelRunner mModelRunner; | ||
|
@@ -50,19 +55,21 @@ protected void onCreate(Bundle savedInstanceState) { | |
} | ||
|
||
mStatsDump = new StatsDump(); | ||
mStatsDump.name = model.getName().replace(".pte", ""); | ||
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); | ||
mStatsDump.loadStart = System.currentTimeMillis(); | ||
mStatsDump.loadStart = System.nanoTime(); | ||
} | ||
|
||
@Override | ||
public void onModelLoaded(int status) { | ||
mStatsDump.loadEnd = System.currentTimeMillis(); | ||
mStatsDump.loadEnd = System.nanoTime(); | ||
mStatsDump.loadStatus = status; | ||
if (status != 0) { | ||
Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); | ||
onGenerationStopped(); | ||
return; | ||
} | ||
mStatsDump.generateStart = System.currentTimeMillis(); | ||
mStatsDump.generateStart = System.nanoTime(); | ||
mModelRunner.generate(mPrompt); | ||
} | ||
|
||
|
@@ -81,36 +88,116 @@ public void onStats(String stats) { | |
|
||
// Callback invoked when token generation finishes (or is invoked directly after a
// load failure). Records the end timestamp, shows the accumulated stats on screen,
// and flushes results to app-local files.
// NOTE(review): this span is a scraped diff view — adjacent duplicated statements
// below (currentTimeMillis vs nanoTime, toJson(mStatsDump) vs toJson(results)) are
// the removed/added sides of the same diff hunk, not both live code. Confirm
// against the merged file before relying on exact statement order.
@Override | ||
public void onGenerationStopped() { | ||
// Removed side of the diff: wall-clock millisecond timer.
mStatsDump.generateEnd = System.currentTimeMillis(); | ||
// Added side: monotonic nanosecond timer, matching generateStart in onModelLoaded.
mStatsDump.generateEnd = System.nanoTime(); | ||
// UI mutation must happen on the main thread.
runOnUiThread( | ||
() -> { | ||
mTextView.append(mStatsDump.toString()); | ||
}); | ||
|
||
// Legacy plain-text dump; per the TODO below it is kept only until the JSON
// format is finalized. NOTE(review): the diff does not make clear whether this
// block survives in the final revision — verify.
// TODO (huydhn): Remove txt files here once the JSON format is ready | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { | ||
writer.write(mStatsDump.toString()); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
// Parse "<name>_<backend>_<quantization>" out of the model file name.
final BenchmarkMetric.BenchmarkModel benchmarkModel = | ||
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.name); | ||
final List<BenchmarkMetric> results = new ArrayList<>(); | ||
// The list of metrics we have atm includes: | ||
// Load status | ||
results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0)); | ||
// Model load time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"model_load_time(ns)", | ||
mStatsDump.loadEnd - mStatsDump.loadStart, | ||
0.0f)); | ||
// LLM generate time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"generate_time(ns)", | ||
mStatsDump.generateEnd - mStatsDump.generateStart, | ||
0.0f)); | ||
// Token per second | ||
results.add( | ||
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f)); | ||
|
||
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something | ||
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { | ||
Gson gson = new Gson(); | ||
// Removed side of the diff: serialized the raw StatsDump.
writer.write(gson.toJson(mStatsDump)); | ||
// Added side: serialize the structured BenchmarkMetric list instead.
writer.write(gson.toJson(results)); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private double extractTPS(final String tokens) { | ||
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); | ||
if (m.find()) { | ||
return Double.parseDouble(m.group()); | ||
} else { | ||
return 0.0f; | ||
} | ||
} | ||
} | ||
|
||
class BenchmarkMetric { | ||
public static class BenchmarkModel { | ||
// The model name, i.e. stories110M | ||
String name; | ||
String backend; | ||
String quantization; | ||
|
||
public BenchmarkModel(final String name, final String backend, final String quantization) { | ||
this.name = name; | ||
this.backend = backend; | ||
this.quantization = quantization; | ||
} | ||
} | ||
|
||
BenchmarkModel benchmarkModel; | ||
|
||
// The metric name, i.e. TPS | ||
String metric; | ||
|
||
// The actual value and the option target value | ||
double actualValue; | ||
double targetValue; | ||
|
||
// Let's see which information we want to include here | ||
final String device = Build.BRAND; | ||
// The phone model and Android release version | ||
final String arch = Build.MODEL; | ||
final String os = "Android " + Build.VERSION.RELEASE; | ||
|
||
|
||
public BenchmarkMetric( | ||
final BenchmarkModel benchmarkModel, | ||
final String metric, | ||
final double actualValue, | ||
final double targetValue) { | ||
this.benchmarkModel = benchmarkModel; | ||
this.metric = metric; | ||
this.actualValue = actualValue; | ||
this.targetValue = targetValue; | ||
} | ||
|
||
// TODO (huydhn): Figure out a way to extract the backend and quantization information from | ||
// the .pte model itself instead of parsing its name | ||
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { | ||
final Matcher m = | ||
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model); | ||
if (m.matches()) { | ||
return new BenchmarkMetric.BenchmarkModel( | ||
m.group("name"), m.group("backend"), m.group("quantization")); | ||
} else { | ||
return new BenchmarkMetric.BenchmarkModel(model, "", ""); | ||
} | ||
} | ||
} | ||
|
||
class StatsDump { | ||
int loadStatus; | ||
long loadStart; | ||
long loadEnd; | ||
long generateStart; | ||
long generateEnd; | ||
String tokens; | ||
String name; | ||
|
||
|
||
@NonNull | ||
@Override | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,34 +47,49 @@ protected void onCreate(Bundle savedInstanceState) { | |
// TODO: Format the string with a parsable format | ||
Stats stats = new Stats(); | ||
|
||
// Record the time it takes to load the model and the forward method | ||
stats.loadStart = System.nanoTime(); | ||
Module module = Module.load(model.getPath()); | ||
stats.errorCode = module.loadMethod("forward"); | ||
stats.loadEnd = System.nanoTime(); | ||
|
||
for (int i = 0; i < numIter; i++) { | ||
long start = System.currentTimeMillis(); | ||
long start = System.nanoTime(); | ||
module.forward(); | ||
long forwardMs = System.currentTimeMillis() - start; | ||
long forwardMs = System.nanoTime() - start; | ||
stats.latency.add(forwardMs); | ||
} | ||
stats.errorCode = module.loadMethod("forward"); | ||
|
||
// TODO (huydhn): Remove txt files here once the JSON format is ready | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I should probably log the time for module.loadMethod() before first forward() 😢 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, let me try to copy it from llama and add one here too |
||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { | ||
writer.write(stats.toString()); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
final BenchmarkMetric.BenchmarkModel benchmarkModel = | ||
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); | ||
final List<BenchmarkMetric> results = new ArrayList<>(); | ||
// The list of metrics we have atm includes: | ||
// Avg inference latency after N iterations | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"avg_inference_latency(ns)", | ||
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), | ||
0.0f)); | ||
// Model load time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, "model_load_time(ns)", stats.loadEnd - stats.loadStart, 0.0f)); | ||
// Load status | ||
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0)); | ||
|
||
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something | ||
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { | ||
Gson gson = new Gson(); | ||
writer.write(gson.toJson(stats)); | ||
writer.write(gson.toJson(results)); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
} | ||
|
||
class Stats { | ||
long loadStart; | ||
long loadEnd; | ||
List<Long> latency = new ArrayList<>(); | ||
int errorCode = 0; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
package org.pytorch.minibench; | ||
|
||
import android.os.Build; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
class BenchmarkMetric { | ||
public static class BenchmarkModel { | ||
// The model name, i.e. stories110M | ||
String name; | ||
String backend; | ||
String quantization; | ||
|
||
public BenchmarkModel(final String name, final String backend, final String quantization) { | ||
this.name = name; | ||
this.backend = backend; | ||
this.quantization = quantization; | ||
} | ||
} | ||
|
||
BenchmarkModel benchmarkModel; | ||
|
||
// The metric name, i.e. TPS | ||
String metric; | ||
|
||
// The actual value and the option target value | ||
double actualValue; | ||
double targetValue; | ||
|
||
// Let's see which information we want to include here | ||
final String device = Build.BRAND; | ||
// The phone model and Android release version | ||
final String arch = Build.MODEL; | ||
final String os = "Android " + Build.VERSION.RELEASE; | ||
|
||
public BenchmarkMetric( | ||
final BenchmarkModel benchmarkModel, | ||
final String metric, | ||
final double actualValue, | ||
final double targetValue) { | ||
this.benchmarkModel = benchmarkModel; | ||
this.metric = metric; | ||
this.actualValue = actualValue; | ||
this.targetValue = targetValue; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's okay that we will have to define a duplicate class for iOS app for now. Later we may want to move it to c++ maybe to make it shareble between the iOS and Android app |
||
|
||
// TODO (huydhn): Figure out a way to extract the backend and quantization information from | ||
// the .pte model itself instead of parsing its name | ||
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { | ||
final Matcher m = | ||
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model); | ||
if (m.matches()) { | ||
return new BenchmarkMetric.BenchmarkModel( | ||
m.group("name"), m.group("backend"), m.group("quantization")); | ||
} else { | ||
return new BenchmarkMetric.BenchmarkModel(model, "", ""); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Device brand may not be very useful, I think more details will be needed, e.g. samsung_s22. More device spec would be helpful as well, e.g. RAM, CPU, OS version, etc. At least RAM I think as RAM would be the bottleneck for most edge models
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As the fields here are flexible, we can definitely add more information about the devices, let me try to add as many as I could find (maybe RAM and CPU info). We can have subsequent PR to add more I guess
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kirklandsign Do you know a way to get the commercial name of the device i.e. s22. The model field kind of map to it, i.e. I search for S901U1 and it means S22, but having a more familiar name make it easier
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lol, it looks like some additional mapping is needed https://github.com/jaredrummler/AndroidDeviceNames. I think it's better then to do it on the dashboard side in this case (when displaying the device on the dashboard)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@huydhn If gathering any of the additional info is non-trivial, let's leave it for future as we are not shooting to get a perfect metrics measurement by PTC. Let's take the low-hanging fruits and merge this PR. I'd rather prioritize to get the similar structured metrics for iOS app, and get the structured metrics displayed in the CI or in the dashboard(optional if not possible by PTC)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I will update iOS benchmark app after this to generate a similar JSON results