-
Notifications
You must be signed in to change notification settings - Fork 687
Define generic Android benchmark metric structure #5332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
27c895d
57c3c5e
5c5499a
0488a12
60800bd
a529c3b
b2b117e
a66b834
2ea88c4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
|
||
import android.app.Activity; | ||
import android.content.Intent; | ||
import android.os.Build; | ||
import android.os.Bundle; | ||
import android.util.Log; | ||
import android.widget.TextView; | ||
|
@@ -18,7 +19,11 @@ | |
import java.io.File; | ||
import java.io.FileWriter; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback { | ||
ModelRunner mModelRunner; | ||
|
@@ -50,6 +55,7 @@ protected void onCreate(Bundle savedInstanceState) { | |
} | ||
|
||
mStatsDump = new StatsDump(); | ||
mStatsDump.name = model.getName().replace(".pte", ""); | ||
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); | ||
mStatsDump.loadStart = System.currentTimeMillis(); | ||
} | ||
|
@@ -87,22 +93,97 @@ public void onGenerationStopped() { | |
mTextView.append(mStatsDump.toString()); | ||
}); | ||
|
||
// TODO (huydhn): Remove txt files here once the JSON format is ready | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { | ||
writer.write(mStatsDump.toString()); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
final BenchmarkMetric.BenchmarkModel benchmarkModel = | ||
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.name); | ||
final List<BenchmarkMetric> results = new ArrayList<>(); | ||
// The list of metrics we have atm includes: | ||
// Model load time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"model_load_time(ms)", | ||
mStatsDump.loadEnd - mStatsDump.loadStart, | ||
0.0f)); | ||
// LLM generate time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"generate_time(ms)", | ||
mStatsDump.generateEnd - mStatsDump.generateStart, | ||
0.0f)); | ||
// Token per second | ||
results.add( | ||
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f)); | ||
|
||
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something | ||
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { | ||
Gson gson = new Gson(); | ||
writer.write(gson.toJson(mStatsDump)); | ||
writer.write(gson.toJson(results)); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private double extractTPS(final String tokens) { | ||
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); | ||
if (m.find()) { | ||
return Double.parseDouble(m.group()); | ||
} else { | ||
return 0.0f; | ||
} | ||
} | ||
} | ||
|
||
class BenchmarkMetric { | ||
public static class BenchmarkModel { | ||
// The model name, i.e. stories110M | ||
String name; | ||
String backend; | ||
String quantization; | ||
|
||
public BenchmarkModel(final String name, final String backend, final String quantization) { | ||
this.name = name; | ||
this.backend = backend; | ||
this.quantization = quantization; | ||
} | ||
} | ||
|
||
BenchmarkModel benchmarkModel; | ||
|
||
// The metric name, i.e. TPS | ||
String metric; | ||
|
||
// The actual value and the option target value | ||
double actual; | ||
double target; | ||
|
||
// Let's see which information we want to include here | ||
final String device = Build.BRAND; | ||
|
||
// The phone model and Android release version | ||
final String arch = Build.MODEL + " / " + Build.VERSION.RELEASE; | ||
|
||
public BenchmarkMetric( | ||
final BenchmarkModel benchmarkModel, | ||
final String metric, | ||
final double actual, | ||
final double target) { | ||
this.benchmarkModel = benchmarkModel; | ||
this.metric = metric; | ||
this.actual = actual; | ||
this.target = target; | ||
} | ||
|
||
// TODO (huydhn): Figure out a way to extract the backend and quantization information from | ||
// the .pte model itself instead of parsing its name | ||
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { | ||
final Matcher m = | ||
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model); | ||
if (m.matches()) { | ||
return new BenchmarkMetric.BenchmarkModel( | ||
m.group("name"), m.group("backend"), m.group("quantization")); | ||
} else { | ||
return new BenchmarkMetric.BenchmarkModel(model, "", ""); | ||
} | ||
} | ||
} | ||
|
||
class StatsDump { | ||
|
@@ -111,6 +192,7 @@ class StatsDump { | |
long generateStart; | ||
long generateEnd; | ||
String tokens; | ||
String name; | ||
|
||
|
||
@NonNull | ||
@Override | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,18 +46,21 @@ protected void onCreate(Bundle savedInstanceState) { | |
stats.latency.add(forwardMs); | ||
} | ||
|
||
// TODO (huydhn): Remove txt files here once the JSON format is ready | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I should probably log the time for module.loadMethod() before first forward() 😢 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, let me try to copy it from llama and add one here too |
||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { | ||
writer.write(stats.toString()); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
final BenchmarkMetric.BenchmarkModel benchmarkModel = | ||
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", "")); | ||
final List<BenchmarkMetric> results = new ArrayList<>(); | ||
// The list of metrics we have atm includes: | ||
// Avg inference latency after N iterations | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"avg_inference_latency(ms)", | ||
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f), | ||
0.0f)); | ||
|
||
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something | ||
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { | ||
Gson gson = new Gson(); | ||
writer.write(gson.toJson(stats)); | ||
writer.write(gson.toJson(results)); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
package org.pytorch.minibench; | ||
|
||
import android.os.Build; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
class BenchmarkMetric { | ||
public static class BenchmarkModel { | ||
// The model name, i.e. stories110M | ||
String name; | ||
String backend; | ||
String quantization; | ||
|
||
public BenchmarkModel(final String name, final String backend, final String quantization) { | ||
this.name = name; | ||
this.backend = backend; | ||
this.quantization = quantization; | ||
} | ||
} | ||
|
||
BenchmarkModel benchmarkModel; | ||
|
||
// The metric name, i.e. TPS | ||
String metric; | ||
|
||
// The actual value and the option target value | ||
double actual; | ||
double target; | ||
|
||
// Let's see which information we want to include here | ||
final String device = Build.BRAND; | ||
// The phone model and Android release version | ||
final String arch = Build.MODEL + " / " + Build.VERSION.RELEASE; | ||
|
||
public BenchmarkMetric( | ||
final BenchmarkModel benchmarkModel, | ||
final String metric, | ||
final double actual, | ||
final double target) { | ||
this.benchmarkModel = benchmarkModel; | ||
this.metric = metric; | ||
this.actual = actual; | ||
this.target = target; | ||
} | ||
|
||
// TODO (huydhn): Figure out a way to extract the backend and quantization information from | ||
// the .pte model itself instead of parsing its name | ||
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) { | ||
final Matcher m = | ||
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model); | ||
if (m.matches()) { | ||
return new BenchmarkMetric.BenchmarkModel( | ||
m.group("name"), m.group("backend"), m.group("quantization")); | ||
} else { | ||
return new BenchmarkMetric.BenchmarkModel(model, "", ""); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,11 @@ | |
import java.io.File; | ||
import java.io.FileWriter; | ||
import java.io.IOException; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.regex.Matcher; | ||
import java.util.regex.Pattern; | ||
|
||
public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { | ||
ModelRunner mModelRunner; | ||
|
@@ -45,6 +49,7 @@ protected void onCreate(Bundle savedInstanceState) { | |
} | ||
|
||
mStatsInfo = new StatsInfo(); | ||
mStatsInfo.name = model.getName().replace(".pte", ""); | ||
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); | ||
mStatsInfo.loadStart = System.currentTimeMillis(); | ||
} | ||
|
@@ -73,22 +78,44 @@ public void onStats(String stats) { | |
public void onGenerationStopped() { | ||
mStatsInfo.generateEnd = System.currentTimeMillis(); | ||
|
||
// TODO (huydhn): Remove txt files here once the JSON format is ready | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { | ||
writer.write(mStatsInfo.toString()); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
final BenchmarkMetric.BenchmarkModel benchmarkModel = | ||
BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.name); | ||
final List<BenchmarkMetric> results = new ArrayList<>(); | ||
// The list of metrics we have atm includes: | ||
// Model load time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"model_load_time(ms)", | ||
mStatsInfo.loadEnd - mStatsInfo.loadStart, | ||
0.0f)); | ||
// LLM generate time | ||
results.add( | ||
new BenchmarkMetric( | ||
benchmarkModel, | ||
"generate_time(ms)", | ||
mStatsInfo.generateEnd - mStatsInfo.generateStart, | ||
0.0f)); | ||
// Token per second | ||
results.add( | ||
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f)); | ||
|
||
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something | ||
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 | ||
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { | ||
Gson gson = new Gson(); | ||
writer.write(gson.toJson(mStatsInfo)); | ||
writer.write(gson.toJson(results)); | ||
} catch (IOException e) { | ||
e.printStackTrace(); | ||
} | ||
} | ||
|
||
private double extractTPS(final String tokens) { | ||
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens); | ||
if (m.find()) { | ||
return Double.parseDouble(m.group()); | ||
} else { | ||
return 0.0f; | ||
} | ||
} | ||
} | ||
|
||
class StatsInfo { | ||
|
@@ -97,6 +124,7 @@ class StatsInfo { | |
long generateStart; | ||
long generateEnd; | ||
String tokens; | ||
String name; | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can add a "loadStatus" for in case of failure |
||
@Override | ||
public String toString() { | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: actualValue, targetValue to reduce future user questions? Or should we give a comment to pointer to https://github.com/pytorch/pytorch/blob/main/benchmarks/gpt_fast/benchmark.py#L25 lol
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can use any name here I guess. I need to write a script to insert the JSON into the database later, so some transformation can be done at that stage instead
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, when I looked at the example in the PR summary, I was having the same question what "target" refers to. In ET target typically means the target device. So yeah, it would be nice to make the name self-explain