1010
1111import android .app .Activity ;
1212import android .content .Intent ;
13+ import android .os .AsyncTask ;
1314import android .os .Bundle ;
14- import android .os .Handler ;
15- import android .os .HandlerThread ;
16- import android .os .Looper ;
15+ import android .os .Debug ;
1716import android .system .ErrnoException ;
1817import android .system .Os ;
1918import com .google .gson .Gson ;
2221import java .io .IOException ;
2322import java .util .ArrayList ;
2423import java .util .Arrays ;
24+ import java .util .Collections ;
2525import java .util .List ;
26+ import java .util .stream .Collectors ;
27+ import org .pytorch .executorch .Module ;
2628
2729public class BenchmarkActivity extends Activity {
28-
29- File mModel ;
30- int mNumIter ;
31- int mNumWarmupIter ;
32- String mTokenizerPath ;
33- float mTemperature ;
34- String mPrompt ;
35-
36- HandlerThread mHandlerThread ;
37- BenchmarkHandler mHandler ;
38-
39- List <BenchmarkMetric > mResult ;
40-
4130 @ Override
4231 protected void onCreate (Bundle savedInstanceState ) {
4332 super .onCreate (savedInstanceState );
@@ -58,79 +47,95 @@ protected void onCreate(Bundle savedInstanceState) {
5847
5948 int numIter = intent .getIntExtra ("num_iter" , 50 );
6049 int numWarmupIter = intent .getIntExtra ("num_warm_up_iter" , 10 );
61- String tokenizerPath = intent .getStringExtra ("tokenizer_path" );
62- float temperature = intent .getFloatExtra ("temperature" , 0.8f );
63- String prompt = intent .getStringExtra ("prompt" );
64-
65- mModel = model ;
66- mNumIter = numIter ;
67- mNumWarmupIter = numWarmupIter ;
68- mTokenizerPath = tokenizerPath ;
69- mTemperature = temperature ;
70- mPrompt = prompt ;
71- if (mPrompt == null ) {
72- mPrompt = "The ultimate answer" ;
73- }
74- mResult = new ArrayList <>();
7550
76- mHandlerThread = new HandlerThread ("ModelRunner" );
77- mHandlerThread .start ();
78- mHandler = new BenchmarkHandler (mHandlerThread .getLooper (), this );
51+ long pssIdle = Debug .getPss ();
7952
80- mHandler . sendEmptyMessage ( BenchmarkHandler . MESSAGE_RUN_BENCHMARK );
81- }
53+ // TODO: Format the string with a parsable format
54+ Stats stats = new Stats ();
8255
83- void writeResult () {
84- try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
85- Gson gson = new Gson ();
86- writer .write (gson .toJson (mResult ));
87- } catch (IOException e ) {
88- e .printStackTrace ();
89- } finally {
90- finish ();
91- }
92- }
93- }
56+ new AsyncTask <Void , Void , Void >() {
57+ @ Override
58+ protected Void doInBackground (Void ... voids ) {
9459
95- class BenchmarkHandler extends Handler {
96- public static int MESSAGE_RUN_BENCHMARK = 1 ;
97- public static int MESSAGE_LLM_RUN_BENCHMARK = 2 ;
60+ // Record the time it takes to load the model and the forward method
61+ stats .loadStart = System .nanoTime ();
62+ Module module = Module .load (model .getPath ());
63+ stats .errorCode = module .loadMethod ("forward" );
64+ stats .loadEnd = System .nanoTime ();
9865
99- ModelRunner mModelRunner ;
100- BenchmarkActivity mBenchmarkActivity ;
66+ for (int i = 0 ; i < numWarmupIter ; i ++) {
67+ module .forward ();
68+ }
10169
102- LlmModelRunner mLlmModelRunner ;
103- LlmBenchmark mLlmBenchmark ;
70+ for (int i = 0 ; i < numIter ; i ++) {
71+ long start = System .nanoTime ();
72+ module .forward ();
73+ double forwardMs = (System .nanoTime () - start ) * 1e-6 ;
74+ stats .latency .add (forwardMs );
75+ }
76+ return null ;
77+ }
10478
105- public BenchmarkHandler (Looper looper , BenchmarkActivity benchmarkActivity ) {
106- super (looper );
107- mModelRunner = new ModelRunner ();
108- mBenchmarkActivity = benchmarkActivity ;
79+ @ Override
80+ protected void onPostExecute (Void aVoid ) {
81+
82+ final BenchmarkMetric .BenchmarkModel benchmarkModel =
83+ BenchmarkMetric .extractBackendAndQuantization (model .getName ().replace (".pte" , "" ));
84+ final List <BenchmarkMetric > results = new ArrayList <>();
85+ // The list of metrics we have atm includes:
86+ // Avg inference latency after N iterations
87+ // Currently the result has large variance from outliers, so only use
88+ // 80% samples in the middle (trimmean 0.2)
89+ Collections .sort (stats .latency );
90+ int resultSize = stats .latency .size ();
91+ List <Double > usedLatencyResults =
92+ stats .latency .subList (resultSize / 10 , resultSize * 9 / 10 );
93+
94+ results .add (
95+ new BenchmarkMetric (
96+ benchmarkModel ,
97+ "avg_inference_latency(ms)" ,
98+ stats .latency .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
99+ 0.0f ));
100+ results .add (
101+ new BenchmarkMetric (
102+ benchmarkModel ,
103+ "trimmean_inference_latency(ms)" ,
104+ usedLatencyResults .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
105+ 0.0f ));
106+ // Model load time
107+ results .add (
108+ new BenchmarkMetric (
109+ benchmarkModel ,
110+ "model_load_time(ms)" ,
111+ (stats .loadEnd - stats .loadStart ) * 1e-6 ,
112+ 0.0f ));
113+ // Load status
114+ results .add (new BenchmarkMetric (benchmarkModel , "load_status" , stats .errorCode , 0 ));
115+ // RAM PSS usage
116+ results .add (
117+ new BenchmarkMetric (
118+ benchmarkModel , "ram_pss_usage(mb)" , (Debug .getPss () - pssIdle ) / 1024 , 0 ));
119+
120+ try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
121+ Gson gson = new Gson ();
122+ writer .write (gson .toJson (results ));
123+ } catch (IOException e ) {
124+ e .printStackTrace ();
125+ }
126+ }
127+ }.execute ();
109128 }
129+ }
130+
// Mutable accumulator for a single benchmark run: model-load timestamps,
// per-iteration forward() latencies, and the loadMethod() status code.
// Fields are package-visible on purpose — BenchmarkActivity.onCreate()
// reads and writes them directly from its background task.
class Stats {
  // System.nanoTime() captured immediately before Module.load(...).
  long loadStart;
  // System.nanoTime() captured immediately after loadMethod("forward").
  long loadEnd;
  // Wall-clock latency of each measured forward() call, in milliseconds.
  List<Double> latency = new ArrayList<>();
  // Result of module.loadMethod("forward"); 0 means success.
  int errorCode = 0;

  @Override
  public String toString() {
    // Join samples with ", " so individual values stay distinguishable.
    // Previously this used joining("") which fused the doubles into an
    // ambiguous blob (e.g. "1.52.5"), making the output unparsable.
    return "latency: " + latency.stream().map(Object::toString).collect(Collectors.joining(", "));
  }
}
0 commit comments