Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import android.app.Activity;
import android.content.Intent;
import android.os.Build;
import android.os.Bundle;
import android.util.Log;
import android.widget.TextView;
Expand All @@ -18,7 +19,11 @@
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class LlmBenchmarkRunner extends Activity implements ModelRunnerCallback {
ModelRunner mModelRunner;
Expand Down Expand Up @@ -50,19 +55,21 @@ protected void onCreate(Bundle savedInstanceState) {
}

mStatsDump = new StatsDump();
mStatsDump.name = model.getName().replace(".pte", "");
mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this);
mStatsDump.loadStart = System.currentTimeMillis();
mStatsDump.loadStart = System.nanoTime();
}

@Override
public void onModelLoaded(int status) {
mStatsDump.loadEnd = System.currentTimeMillis();
mStatsDump.loadEnd = System.nanoTime();
mStatsDump.loadStatus = status;
if (status != 0) {
Log.e("LlmBenchmarkRunner", "Loaded failed: " + status);
onGenerationStopped();
return;
}
mStatsDump.generateStart = System.currentTimeMillis();
mStatsDump.generateStart = System.nanoTime();
mModelRunner.generate(mPrompt);
}

Expand All @@ -81,36 +88,116 @@ public void onStats(String stats) {

@Override
public void onGenerationStopped() {
mStatsDump.generateEnd = System.currentTimeMillis();
mStatsDump.generateEnd = System.nanoTime();
runOnUiThread(
() -> {
mTextView.append(mStatsDump.toString());
});

// TODO (huydhn): Remove txt files here once the JSON format is ready
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(mStatsDump.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(mStatsDump.name);
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", mStatsDump.loadStatus, 0));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel,
"model_load_time(ns)",
mStatsDump.loadEnd - mStatsDump.loadStart,
0.0f));
// LLM generate time
results.add(
new BenchmarkMetric(
benchmarkModel,
"generate_time(ns)",
mStatsDump.generateEnd - mStatsDump.generateStart,
0.0f));
// Token per second
results.add(
new BenchmarkMetric(benchmarkModel, "token_per_sec", extractTPS(mStatsDump.tokens), 0.0f));

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(mStatsDump));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}

private double extractTPS(final String tokens) {
final Matcher m = Pattern.compile("\\d+\\.?\\d*").matcher(tokens);
if (m.find()) {
return Double.parseDouble(m.group());
} else {
return 0.0f;
}
}
}

class BenchmarkMetric {
public static class BenchmarkModel {
// The model name, i.e. stories110M
String name;
String backend;
String quantization;

public BenchmarkModel(final String name, final String backend, final String quantization) {
this.name = name;
this.backend = backend;
this.quantization = quantization;
}
}

BenchmarkModel benchmarkModel;

// The metric name, i.e. TPS
String metric;

// The actual value and the option target value
double actualValue;
double targetValue;

// Let's see which information we want to include here
final String device = Build.BRAND;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Device brand may not be very useful, I think more details will be needed, e.g. samsung_s22. More device spec would be helpful as well, e.g. RAM, CPU, OS version, etc. At least RAM I think as RAM would be the bottleneck for most edge models

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the fields here are flexible, we can definitely add more information about the devices, let me try to add as many as I could find (maybe RAM and CPU info). We can have subsequent PR to add more I guess

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kirklandsign Do you know a way to get the commercial name of the device i.e. s22. The model field kind of map to it, i.e. I search for S901U1 and it means S22, but having a more familiar name make it easier

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol, it looks like some additional mapping is needed https://github.com/jaredrummler/AndroidDeviceNames. I think it's better then to do it on the dashboard side in this case (when displaying the device on the dashboard)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@huydhn If gathering any of the additional info is non-trivial, let's leave it for future as we are not shooting to get a perfect metrics measurement by PTC. Let's take the low-hanging fruits and merge this PR. I'd rather prioritize to get the similar structured metrics for iOS app, and get the structured metrics displayed in the CI or in the dashboard(optional if not possible by PTC)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I will update iOS benchmark app after this to generate a similar JSON results

// The phone model and Android release version
final String arch = Build.MODEL;
final String os = "Android " + Build.VERSION.RELEASE;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay OS version is covered


public BenchmarkMetric(
final BenchmarkModel benchmarkModel,
final String metric,
final double actualValue,
final double targetValue) {
this.benchmarkModel = benchmarkModel;
this.metric = metric;
this.actualValue = actualValue;
this.targetValue = targetValue;
}

// TODO (huydhn): Figure out a way to extract the backend and quantization information from
// the .pte model itself instead of parsing its name
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
final Matcher m =
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
if (m.matches()) {
return new BenchmarkMetric.BenchmarkModel(
m.group("name"), m.group("backend"), m.group("quantization"));
} else {
return new BenchmarkMetric.BenchmarkModel(model, "", "");
}
}
}

class StatsDump {
int loadStatus;
long loadStart;
long loadEnd;
long generateStart;
long generateEnd;
String tokens;
String name;
Copy link
Contributor

@guangy10 guangy10 Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name of what? Can we add a comment or make it self-explain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's the model name, let me update the variable to call it so


@NonNull
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,34 +47,49 @@ protected void onCreate(Bundle savedInstanceState) {
// TODO: Format the string with a parsable format
Stats stats = new Stats();

// Record the time it takes to load the model and the forward method
stats.loadStart = System.nanoTime();
Module module = Module.load(model.getPath());
stats.errorCode = module.loadMethod("forward");
stats.loadEnd = System.nanoTime();

for (int i = 0; i < numIter; i++) {
long start = System.currentTimeMillis();
long start = System.nanoTime();
module.forward();
long forwardMs = System.currentTimeMillis() - start;
long forwardMs = System.nanoTime() - start;
stats.latency.add(forwardMs);
}
stats.errorCode = module.loadMethod("forward");

// TODO (huydhn): Remove txt files here once the JSON format is ready
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably log the time for module.loadMethod() before first forward() 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, let me try to copy it from llama and add one here too

try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
writer.write(stats.toString());
} catch (IOException e) {
e.printStackTrace();
}
final BenchmarkMetric.BenchmarkModel benchmarkModel =
BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
final List<BenchmarkMetric> results = new ArrayList<>();
// The list of metrics we have atm includes:
// Avg inference latency after N iterations
results.add(
new BenchmarkMetric(
benchmarkModel,
"avg_inference_latency(ns)",
stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
0.0f));
// Model load time
results.add(
new BenchmarkMetric(
benchmarkModel, "model_load_time(ns)", stats.loadEnd - stats.loadStart, 0.0f));
// Load status
results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));

// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
Gson gson = new Gson();
writer.write(gson.toJson(stats));
writer.write(gson.toJson(results));
} catch (IOException e) {
e.printStackTrace();
}
}
}

class Stats {
long loadStart;
long loadEnd;
List<Long> latency = new ArrayList<>();
int errorCode = 0;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.minibench;

import android.os.Build;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class BenchmarkMetric {
public static class BenchmarkModel {
// The model name, i.e. stories110M
String name;
String backend;
String quantization;

public BenchmarkModel(final String name, final String backend, final String quantization) {
this.name = name;
this.backend = backend;
this.quantization = quantization;
}
}

BenchmarkModel benchmarkModel;

// The metric name, i.e. TPS
String metric;

// The actual value and the option target value
double actualValue;
double targetValue;

// Let's see which information we want to include here
final String device = Build.BRAND;
// The phone model and Android release version
final String arch = Build.MODEL;
final String os = "Android " + Build.VERSION.RELEASE;

public BenchmarkMetric(
final BenchmarkModel benchmarkModel,
final String metric,
final double actualValue,
final double targetValue) {
this.benchmarkModel = benchmarkModel;
this.metric = metric;
this.actualValue = actualValue;
this.targetValue = targetValue;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay that we will have to define a duplicate class for iOS app for now. Later we may want to move it to c++ maybe to make it shareble between the iOS and Android app


// TODO (huydhn): Figure out a way to extract the backend and quantization information from
// the .pte model itself instead of parsing its name
public static BenchmarkMetric.BenchmarkModel extractBackendAndQuantization(final String model) {
final Matcher m =
Pattern.compile("(?<name>\\w+)_(?<backend>\\w+)_(?<quantization>\\w+)").matcher(model);
if (m.matches()) {
return new BenchmarkMetric.BenchmarkModel(
m.group("name"), m.group("backend"), m.group("quantization"));
} else {
return new BenchmarkMetric.BenchmarkModel(model, "", "");
}
}
}
Loading
Loading