Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import android.app.Activity;
import android.content.Intent;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
Expand All @@ -18,7 +19,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
Expand Down Expand Up @@ -50,6 +55,7 @@ protected void onCreate(Bundle savedInstanceState) {
}

mStatsDump = new StatsDump();
mStatsDump.name = model.getName().replace(".pte", "");
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
mStatsDump.loadStart = System.currentTimeMillis();
}
Expand Down Expand Up @@ -87,22 +93,97 @@ public void onGenerationStopped() {
mTextView.append(mStatsDump.toString());
});

// TODO (huydhn): Remove txt files here once the JSON format is ready
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(mStatsDump.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.name);
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
mStatsDump.loadEnd - mStatsDump.loadStart,
0.0f));
// LLM generate time
results.add(
new BenchmarkMetric(
benchmarkModel,
"generate_time(ms)",
mStatsDump.generateEnd - mStatsDump.generateStart,
0.0f));
// Token per second
results.add(
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mStatsDump));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}

// Pulls the first numeric value out of the raw stats string (assumed to carry a
// tokens-per-second figure — TODO confirm format against ModelRunner output).
// Falls back to 0.0 when no number is present.
private double extractTPS(final String tokens) {
  final Matcher matcher = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
  return matcher.find() ? Double.parseDouble(matcher.group()) : 0.0;
}
}

class BenchmarkMetric {
public static class BenchmarkModel {
    // The model name, i.e. stories110M
    String name;
    // Backend parsed from the .pte file name by extractBackendAndQuantization;
    // empty string when the name does not match the expected pattern
    String backend;
    // Quantization scheme parsed from the .pte file name; empty string when unknown
    String quantization;

    // Plain value holder; fields are serialized as-is into the benchmark JSON via Gson.
    public BenchmarkModel(final String name, final String backend, final String quantization) {
      this.name = name;
      this.backend = backend;
      this.quantization = quantization;
    }
  }

BenchmarkModel benchmarkModel;

// The metric name, i.e. TPS
String metric;

// The actual value and the option target value
double actual;
double target;
Copy link
Contributor

@kirklandsign kirklandsign Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: actualValue, targetValue to reduce future user questions? Or should we give a comment to pointer to https://github.com/pytorch/pytorch/blob/main/benchmarks/gpt_fast/benchmark.py#L25 lol

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use any name here I guess. I need to write a script to insert the JSON into the database later, so some transformation can be done at that stage instead

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, when I looked at the example in the PR summary, I was having the same question what "target" refers to. In ET target typically means the target device. So yeah, it would be nice to make the name self-explain


// Let's see which information we want to include here
final String device = Build.BRAND;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Device brand may not be very useful, I think more details will be needed, e.g. samsung_s22. More device spec would be helpful as well, e.g. RAM, CPU, OS version, etc. At least RAM I think as RAM would be the bottleneck for most edge models

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the fields here are flexible, we can definitely add more information about the devices, let me try to add as many as I could find (maybe RAM and CPU info). We can have subsequent PR to add more I guess

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kirklandsign Do you know a way to get the commercial name of the device i.e. s22. The model field kind of map to it, i.e. I search for S901U1 and it means S22, but having a more familiar name make it easier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol, it looks like some additional mapping is needed https://github.com/jaredrummler/AndroidDeviceNames. I think it's better then to do it on the dashboard side in this case (when displaying the device on the dashboard)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@huydhn If gathering any of the additional info is non-trivial, let's leave it for future as we are not shooting to get a perfect metrics measurement by PTC. Let's take the low-hanging fruits and merge this PR. I'd rather prioritize to get the similar structured metrics for iOS app, and get the structured metrics displayed in the CI or in the dashboard(optional if not possible by PTC)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I will update iOS benchmark app after this to generate a similar JSON results

// The phone model and Android release version
final String arch = Build.MODEL + " / " + Build.VERSION.RELEASE;

// Builds one benchmark record: the model it was measured on, the metric name
// (e.g. "token_per_sec"), the measured value, and an optional target value
// (callers currently pass 0 when there is no target).
public BenchmarkMetric(
    final BenchmarkModel benchmarkModel,
    final String metric,
    final double actual,
    final double target) {
  this.benchmarkModel = benchmarkModel;
  this.metric = metric;
  this.actual = actual;
  this.target = target;
}

// TODO (huydhn): Figure out a way to extract the backend and quantization information from
// the .pte model itself instead of parsing its name
// Splits a model file name of the form <name>_<backend>_<quantization>
// (e.g. stories110M_xnnpack_q8) into its parts. Names that do not match the
// pattern yield a BenchmarkModel with empty backend and quantization fields.
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
  final Matcher matcher =
      Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
  if (!matcher.matches()) {
    return new BenchmarkMetric.BenchmarkModel(model, "", "");
  }
  return new BenchmarkMetric.BenchmarkModel(
      matcher.group("name"), matcher.group("backend"), matcher.group("quantization"));
}
}

class StatsDump {
Expand All @@ -111,6 +192,7 @@ class StatsDump {
long generateStart;
long generateEnd;
String tokens;
String name;
Copy link
Contributor

@guangy10 guangy10 Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name of what? Can we add a comment or make it self-explain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's the model name, let me update the variable to call it so


@NonNull
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,21 @@ protected void onCreate(Bundle savedInstanceState) {
stats.latency.add(forwardMs);
}

// TODO (huydhn): Remove txt files here once the JSON format is ready
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably log the time for module.loadMethod() before first forward() 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, let me try to copy it from llama and add one here too

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(stats.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ms)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(stats));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

import android.os.Build;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** One benchmark data point, serialized to JSON via Gson for CI consumption. */
class BenchmarkMetric {
  public static class BenchmarkModel {
    // The model name, i.e. stories110M
    String name;
    // Backend parsed from the .pte file name; empty string when unknown
    String backend;
    // Quantization scheme parsed from the .pte file name; empty string when unknown
    String quantization;

    public BenchmarkModel(final String name, final String backend, final String quantization) {
      this.name = name;
      this.backend = backend;
      this.quantization = quantization;
    }
  }

  // Matches model names of the form <name>_<backend>_<quantization>, e.g.
  // stories110M_xnnpack_q8. Compiled once instead of on every call; static
  // fields are skipped by Gson, so the serialized JSON is unchanged.
  private static final Pattern MODEL_NAME_PATTERN =
      Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)");

  BenchmarkModel benchmarkModel;

  // The metric name, i.e. TPS
  String metric;

  // The actual value and the optional target value
  double actual;
  double target;

  // Let's see which information we want to include here
  final String device = Build.BRAND;
  // The phone model and Android release version
  final String arch = Build.MODEL + " / " + Build.VERSION.RELEASE;

  /**
   * Builds one benchmark record: the model it was measured on, the metric name (e.g.
   * "avg_inference_latency(ms)"), the measured value, and an optional target value (callers
   * currently pass 0 when there is no target).
   */
  public BenchmarkMetric(
      final BenchmarkModel benchmarkModel,
      final String metric,
      final double actual,
      final double target) {
    this.benchmarkModel = benchmarkModel;
    this.metric = metric;
    this.actual = actual;
    this.target = target;
  }

  // TODO (huydhn): Figure out a way to extract the backend and quantization information from
  // the .pte model itself instead of parsing its name
  public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
    final Matcher m = MODEL_NAME_PATTERN.matcher(model);
    if (m.matches()) {
      return new BenchmarkMetric.BenchmarkModel(
          m.group("name"), m.group("backend"), m.group("quantization"));
    } else {
      return new BenchmarkMetric.BenchmarkModel(model, "", "");
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
Expand Down Expand Up @@ -45,6 +49,7 @@ protected void onCreate(Bundle savedInstanceState) {
}

mStatsInfo = new StatsInfo();
mStatsInfo.name = model.getName().replace(".pte", "");
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
mStatsInfo.loadStart = System.currentTimeMillis();
}
Expand Down Expand Up @@ -73,22 +78,44 @@ public void onStats(String stats) {
public void onGenerationStopped() {
mStatsInfo.generateEnd = System.currentTimeMillis();

// TODO (huydhn): Remove txt files here once the JSON format is ready
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(mStatsInfo.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(mStatsInfo.name);
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ms)",
mStatsInfo.loadEnd - mStatsInfo.loadStart,
0.0f));
// LLM generate time
results.add(
new BenchmarkMetric(
benchmarkModel,
"generate_time(ms)",
mStatsInfo.generateEnd - mStatsInfo.generateStart,
0.0f));
// Token per second
results.add(
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsInfo.tokens), 0.0f));

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mStatsInfo));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}

// Extracts the first number found in the stats string (presumably the
// tokens-per-second value — verify against the generate callback's format).
// Returns 0.0 if the string contains no digits.
private double extractTPS(final String tokens) {
  final Matcher matcher = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
  if (!matcher.find()) {
    return 0.0;
  }
  return Double.parseDouble(matcher.group());
}
}

class StatsInfo {
Expand All @@ -97,6 +124,7 @@ class StatsInfo {
long generateStart;
long generateEnd;
String tokens;
String name;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add a "loadStatus" for

in case of failure

@Override
public String toString() {
Expand Down
Loading