Skip to content

Commit 8874de2

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Benchmark app update (#5240)
Summary: Pull Request resolved: #5240 Reviewed By: shoumikhin Differential Revision: D62557247 Pulled By: kirklandsign fbshipit-source-id: 977ced7e241efddb213485abd0fc03392cb3add0
1 parent 8888c0d commit 8874de2

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

extension/android/benchmark/android-llm-device-farm-test-spec.yml

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,21 @@ phases:
1010
commands:
1111
# Prepare the model and the tokenizer
1212
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /sdcard/"
13-
- adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/llama/"
14-
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/llama/"
15-
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/llama/"
16-
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.bin"
17-
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/llama/*.pte"
18-
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/llama/"
13+
- adb -s $DEVICEFARM_DEVICE_UDID shell "mkdir -p /data/local/tmp/minibench/"
14+
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.bin /data/local/tmp/minibench/"
15+
- adb -s $DEVICEFARM_DEVICE_UDID shell "mv /sdcard/*.pte /data/local/tmp/minibench/"
16+
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/minibench/*.bin"
17+
- adb -s $DEVICEFARM_DEVICE_UDID shell "chmod 664 /data/local/tmp/minibench/*.pte"
18+
- adb -s $DEVICEFARM_DEVICE_UDID shell "ls -la /data/local/tmp/minibench/"
19+
- adb -s $DEVICEFARM_DEVICE_UDID shell "run-as org.pytorch.minibench rm -rf files"
1920

2021
test:
2122
commands:
2223
# By default, the following ADB command is used by Device Farm to run your Instrumentation test.
2324
# Please refer to Android's documentation for more options on running instrumentation tests with adb:
2425
# https://developer.android.com/studio/test/command-line#run-tests-with-adb
26+
27+
# Run the Instrumentation test for sanity check
2528
- echo "Starting the Instrumentation test"
2629
- |
2730
adb -s $DEVICEFARM_DEVICE_UDID shell "am instrument -r -w --no-window-animation \
@@ -67,17 +70,33 @@ phases:
6770
fi;
6871
6972
# Run the new generic benchmark activity https://developer.android.com/tools/adb#am
70-
- echo "Run LLM benchmark"
73+
- echo "Determine model type"
74+
- |
75+
BIN_FOUND="$(adb -s $DEVICEFARM_DEVICE_UDID shell find /data/local/tmp/minibench/ -name '*.bin')"
76+
if [ -z "$BIN_FOUND" ]; then
77+
echo "No tokenizer files found in /data/local/tmp/minibench/"
78+
else
79+
echo "tokenizer files found in /data/local/tmp/minibench/"
80+
fi
81+
82+
- echo "Run benchmark"
7183
- |
72-
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
73-
--es "model_dir" "/data/local/tmp/llama" \
74-
--es "tokenizer_path" "/data/local/tmp/llama/tokenizer.bin"
84+
adb -s $DEVICEFARM_DEVICE_UDID shell am force-stop org.pytorch.minibench
85+
if [ -z "$BIN_FOUND" ]; then
86+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.BenchmarkActivity \
87+
--es "model_dir" "/data/local/tmp/minibench"
88+
else
89+
adb -s $DEVICEFARM_DEVICE_UDID shell am start -W -n org.pytorch.minibench/.LlmBenchmarkActivity \
90+
--es "model_dir" "/data/local/tmp/minibench" \
91+
--es "tokenizer_path" "/data/local/tmp/minibench/tokenizer.bin"
92+
fi
93+
7594
7695
post_test:
7796
commands:
78-
- echo "Gather LLM benchmark results"
97+
- echo "Gather benchmark results"
7998
- |
80-
BENCHMARK_RESULTS=""
99+
BENCHMARK_RESULTS=$(adb -s $DEVICEFARM_DEVICE_UDID shell run-as org.pytorch.minibench cat files/benchmark_results.json)
81100
ATTEMPT=0
82101
MAX_ATTEMPT=10
83102
while [ -z "${BENCHMARK_RESULTS}" ] && [ $ATTEMPT -lt $MAX_ATTEMPT ]; do

extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@
1111
import android.app.Activity;
1212
import android.content.Intent;
1313
import android.os.Bundle;
14+
import com.google.gson.Gson;
1415
import java.io.File;
1516
import java.io.FileWriter;
1617
import java.io.IOException;
18+
import java.util.ArrayList;
1719
import java.util.Arrays;
20+
import java.util.List;
21+
import java.util.stream.Collectors;
1822
import org.pytorch.executorch.Module;
1923

2024
public class BenchmarkActivity extends Activity {
@@ -32,20 +36,39 @@ protected void onCreate(Bundle savedInstanceState) {
3236
int numIter = intent.getIntExtra("num_iter", 10);
3337

3438
// TODO: Format the string with a parsable format
35-
StringBuilder resultText = new StringBuilder();
39+
Stats stats = new Stats();
3640

3741
Module module = Module.load(model.getPath());
3842
for (int i = 0; i < numIter; i++) {
3943
long start = System.currentTimeMillis();
4044
module.forward();
4145
long forwardMs = System.currentTimeMillis() - start;
42-
resultText.append(forwardMs).append(";");
46+
stats.latency.add(forwardMs);
4347
}
4448

49+
// TODO (huydhn): Remove txt files here once the JSON format is ready
4550
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) {
46-
writer.write(resultText.toString());
51+
writer.write(stats.toString());
4752
} catch (IOException e) {
4853
e.printStackTrace();
4954
}
55+
56+
// TODO (huydhn): Figure out on what the final JSON results looks like, we need something
57+
// with the same number of fields as https://github.com/pytorch/pytorch/pull/135042
58+
try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
59+
Gson gson = new Gson();
60+
writer.write(gson.toJson(stats));
61+
} catch (IOException e) {
62+
e.printStackTrace();
63+
}
64+
}
65+
}
66+
67+
class Stats {
68+
List<Long> latency = new ArrayList<>();
69+
70+
@Override
71+
public String toString() {
72+
return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(""));
5073
}
5174
}

0 commit comments

Comments
 (0)