
Commit aee35f9

Commit message:

Update

[ghstack-poisoned]

2 parents: 5db4c4a + 4d42ee4

File tree

2 files changed: +63 −36 lines


.github/workflows/ghstack_land.yml

Lines changed: 14 additions & 2 deletions
@@ -1,9 +1,21 @@
 name: Propose to merge ghstack orig PRs to main
 on:
   pull_request:
-    types: [opened, synchronize, closed]
+    types: [closed]
     branches:
-      - 'gh/*/[0-9]+/base'
+      - 'gh/cccclai/[0-9]+/base'
+      - 'gh/dbort/[0-9]+/base'
+      - 'gh/guangy10/[0-9]+/base'
+      - 'gh/helunwencser/[0-9]+/base'
+      - 'gh/jorgep31415/[0-9]+/base'
+      - 'gh/kimishpatel/[0-9]+/base'
+      - 'gh/kirklandsign/[0-9]+/base'
+      - 'gh/larryliu0820/[0-9]+/base'
+      - 'gh/manuelcandales/[0-9]+/base'
+      - 'gh/mcr229/[0-9]+/base'
+      - 'gh/swolchok/[0-9]+/base'
+      - 'gh/SS-JIA/[0-9]+/base'
+
 jobs:
   ghstack_merge_to_main:
     name: Try to create a PR with ghstack /orig branch
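
This hunk narrows the trigger to closed pull requests only and replaces the single wildcard branch filter with an explicit list of contributors' ghstack base branches. For reference, GitHub Actions branch filters are glob-style patterns rather than regular expressions: `[0-9]` matches a single digit and `+` means one or more of the preceding character, so each entry matches branches like gh/dbort/123/base. A minimal illustrative trigger using a placeholder username (not taken from this commit) would look like:

```yaml
# Illustrative sketch only; "someuser" is a placeholder, not part of this commit.
# The filter matches ghstack base branches with a numeric stack index,
# e.g. gh/someuser/123/base.
on:
  pull_request:
    types: [closed]
    branches:
      - 'gh/someuser/[0-9]+/base'
```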

extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java

Lines changed: 49 additions & 34 deletions
@@ -10,6 +10,7 @@
 
 import android.app.Activity;
 import android.content.Intent;
+import android.os.AsyncTask;
 import android.os.Bundle;
 import android.system.ErrnoException;
 import android.system.Os;
@@ -47,43 +48,57 @@ protected void onCreate(Bundle savedInstanceState) {
     // TODO: Format the string with a parsable format
     Stats stats = new Stats();
 
-    // Record the time it takes to load the model and the forward method
-    stats.loadStart = System.nanoTime();
-    Module module = Module.load(model.getPath());
-    stats.errorCode = module.loadMethod("forward");
-    stats.loadEnd = System.nanoTime();
+    new AsyncTask<Void, Void, Void>() {
+      @Override
+      protected Void doInBackground(Void... voids) {
 
-    for (int i = 0; i < numIter; i++) {
-      long start = System.nanoTime();
-      module.forward();
-      double forwardMs = (System.nanoTime() - start) * 1e-6;
-      stats.latency.add(forwardMs);
-    }
+        // Record the time it takes to load the model and the forward method
+        stats.loadStart = System.nanoTime();
+        Module module = Module.load(model.getPath());
+        stats.errorCode = module.loadMethod("forward");
+        stats.loadEnd = System.nanoTime();
 
-    final BenchmarkMetric.BenchmarkModel benchmarkModel =
-        BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
-    final List<BenchmarkMetric> results = new ArrayList<>();
-    // The list of metrics we have atm includes:
-    // Avg inference latency after N iterations
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel,
-            "avg_inference_latency(ms)",
-            stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
-            0.0f));
-    // Model load time
-    results.add(
-        new BenchmarkMetric(
-            benchmarkModel, "model_load_time(ms)", (stats.loadEnd - stats.loadStart) * 1e-6, 0.0f));
-    // Load status
-    results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
+        for (int i = 0; i < numIter; i++) {
+          long start = System.nanoTime();
+          module.forward();
+          double forwardMs = (System.nanoTime() - start) * 1e-6;
+          stats.latency.add(forwardMs);
+        }
+        return null;
+      }
 
-    try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
-      Gson gson = new Gson();
-      writer.write(gson.toJson(results));
-    } catch (IOException e) {
-      e.printStackTrace();
-    }
+      @Override
+      protected void onPostExecute(Void aVoid) {
+
+        final BenchmarkMetric.BenchmarkModel benchmarkModel =
+            BenchmarkMetric.extractBackendAndQuantization(model.getName().replace(".pte", ""));
+        final List<BenchmarkMetric> results = new ArrayList<>();
+        // The list of metrics we have atm includes:
+        // Avg inference latency after N iterations
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel,
+                "avg_inference_latency(ms)",
+                stats.latency.stream().mapToDouble(l -> l).average().orElse(0.0f),
+                0.0f));
+        // Model load time
+        results.add(
+            new BenchmarkMetric(
+                benchmarkModel,
+                "model_load_time(ms)",
+                (stats.loadEnd - stats.loadStart) * 1e-6,
+                0.0f));
+        // Load status
+        results.add(new BenchmarkMetric(benchmarkModel, "load_status", stats.errorCode, 0));
+
+        try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) {
+          Gson gson = new Gson();
+          writer.write(gson.toJson(results));
+        } catch (IOException e) {
+          e.printStackTrace();
+        }
+      }
+    }.execute();
   }
 }
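
This hunk moves the benchmark off the UI thread with an anonymous android.os.AsyncTask: doInBackground runs the model load and the timed forward() iterations on a worker thread, and onPostExecute, which runs back on the main thread once the background work returns, assembles the metrics and writes benchmark_results.json. Below is a minimal, self-contained sketch of that execution pattern; the class and helper names are hypothetical and this is not the ExecuTorch benchmark code (note that AsyncTask has been deprecated since API level 30 but is still functional).

```java
// Minimal sketch of the AsyncTask pattern used in the diff above.
// Hypothetical names (TimingTask, doExpensiveWork), not the ExecuTorch code.
// doInBackground(...) runs on a worker thread; onPostExecute(...) runs on the
// main/UI thread after the background work returns.
import android.os.AsyncTask;
import android.util.Log;

final class TimingTask extends AsyncTask<Void, Void, Double> {
  @Override
  protected Double doInBackground(Void... voids) {
    // Long-running work belongs here (e.g. loading a model and timing runs).
    long start = System.nanoTime();
    doExpensiveWork();
    return (System.nanoTime() - start) * 1e-6; // elapsed time in milliseconds
  }

  @Override
  protected void onPostExecute(Double elapsedMs) {
    // Runs on the main thread; a safe place to publish results or touch views.
    Log.i("TimingTask", "work took " + elapsedMs + " ms");
  }

  private static void doExpensiveWork() {
    try {
      Thread.sleep(100); // stand-in for real work
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }
}

// Usage, from an Activity's onCreate() on the main thread:
//   new TimingTask().execute();
```

In the diff above, the same shape appears inline as an anonymous AsyncTask subclass created and started with .execute() directly inside onCreate.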
