1010
1111import android .app .Activity ;
1212import android .content .Intent ;
13- import android .os .AsyncTask ;
1413import android .os .Bundle ;
15- import android .os .Debug ;
14+ import android .os .Handler ;
15+ import android .os .HandlerThread ;
16+ import android .os .Looper ;
1617import android .system .ErrnoException ;
1718import android .system .Os ;
1819import com .google .gson .Gson ;
2122import java .io .IOException ;
2223import java .util .ArrayList ;
2324import java .util .Arrays ;
24- import java .util .Collections ;
2525import java .util .List ;
26- import java .util .stream .Collectors ;
27- import org .pytorch .executorch .Module ;
2826
2927public class BenchmarkActivity extends Activity {
28+
29+ File mModel ;
30+ int mNumIter ;
31+ int mNumWarmupIter ;
32+ String mTokenizerPath ;
33+ float mTemperature ;
34+ String mPrompt ;
35+
36+ HandlerThread mHandlerThread ;
37+ BenchmarkHandler mHandler ;
38+
39+ List <BenchmarkMetric > mResult ;
40+
3041 @ Override
3142 protected void onCreate (Bundle savedInstanceState ) {
3243 super .onCreate (savedInstanceState );
@@ -47,95 +58,79 @@ protected void onCreate(Bundle savedInstanceState) {
4758
4859 int numIter = intent .getIntExtra ("num_iter" , 50 );
4960 int numWarmupIter = intent .getIntExtra ("num_warm_up_iter" , 10 );
61+ String tokenizerPath = intent .getStringExtra ("tokenizer_path" );
62+ float temperature = intent .getFloatExtra ("temperature" , 0.8f );
63+ String prompt = intent .getStringExtra ("prompt" );
64+
65+ mModel = model ;
66+ mNumIter = numIter ;
67+ mNumWarmupIter = numWarmupIter ;
68+ mTokenizerPath = tokenizerPath ;
69+ mTemperature = temperature ;
70+ mPrompt = prompt ;
71+ if (mPrompt == null ) {
72+ mPrompt = "The ultimate answer" ;
73+ }
74+ mResult = new ArrayList <>();
5075
51- long pssIdle = Debug .getPss ();
76+ mHandlerThread = new HandlerThread ("ModelRunner" );
77+ mHandlerThread .start ();
78+ mHandler = new BenchmarkHandler (mHandlerThread .getLooper (), this );
5279
53- // TODO: Format the string with a parsable format
54- Stats stats = new Stats ();
80+ mHandler . sendEmptyMessage ( BenchmarkHandler . MESSAGE_RUN_BENCHMARK );
81+ }
5582
56- new AsyncTask <Void , Void , Void >() {
57- @ Override
58- protected Void doInBackground (Void ... voids ) {
83+ void writeResult () {
84+ try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
85+ Gson gson = new Gson ();
86+ writer .write (gson .toJson (mResult ));
87+ } catch (IOException e ) {
88+ e .printStackTrace ();
89+ } finally {
90+ finish ();
91+ }
92+ }
93+ }
5994
60- // Record the time it takes to load the model and the forward method
61- stats .loadStart = System .nanoTime ();
62- Module module = Module .load (model .getPath ());
63- stats .errorCode = module .loadMethod ("forward" );
64- stats .loadEnd = System .nanoTime ();
95+ class BenchmarkHandler extends Handler {
96+ public static int MESSAGE_RUN_BENCHMARK = 1 ;
97+ public static int MESSAGE_LLM_RUN_BENCHMARK = 2 ;
6598
66- for (int i = 0 ; i < numWarmupIter ; i ++) {
67- module .forward ();
68- }
99+ ModelRunner mModelRunner ;
100+ BenchmarkActivity mBenchmarkActivity ;
69101
70- for (int i = 0 ; i < numIter ; i ++) {
71- long start = System .nanoTime ();
72- module .forward ();
73- double forwardMs = (System .nanoTime () - start ) * 1e-6 ;
74- stats .latency .add (forwardMs );
75- }
76- return null ;
77- }
102+ LlmModelRunner mLlmModelRunner ;
103+ LlmBenchmark mLlmBenchmark ;
78104
79- @ Override
80- protected void onPostExecute (Void aVoid ) {
81-
82- final BenchmarkMetric .BenchmarkModel benchmarkModel =
83- BenchmarkMetric .extractBackendAndQuantization (model .getName ().replace (".pte" , "" ));
84- final List <BenchmarkMetric > results = new ArrayList <>();
85- // The list of metrics we have atm includes:
86- // Avg inference latency after N iterations
87- // Currently the result has large variance from outliers, so only use
88- // 80% samples in the middle (trimmean 0.2)
89- Collections .sort (stats .latency );
90- int resultSize = stats .latency .size ();
91- List <Double > usedLatencyResults =
92- stats .latency .subList (resultSize / 10 , resultSize * 9 / 10 );
93-
94- results .add (
95- new BenchmarkMetric (
96- benchmarkModel ,
97- "avg_inference_latency(ms)" ,
98- stats .latency .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
99- 0.0f ));
100- results .add (
101- new BenchmarkMetric (
102- benchmarkModel ,
103- "trimmean_inference_latency(ms)" ,
104- usedLatencyResults .stream ().mapToDouble (l -> l ).average ().orElse (0.0f ),
105- 0.0f ));
106- // Model load time
107- results .add (
108- new BenchmarkMetric (
109- benchmarkModel ,
110- "model_load_time(ms)" ,
111- (stats .loadEnd - stats .loadStart ) * 1e-6 ,
112- 0.0f ));
113- // Load status
114- results .add (new BenchmarkMetric (benchmarkModel , "load_status" , stats .errorCode , 0 ));
115- // RAM PSS usage
116- results .add (
117- new BenchmarkMetric (
118- benchmarkModel , "ram_pss_usage(mb)" , (Debug .getPss () - pssIdle ) / 1024 , 0 ));
119-
120- try (FileWriter writer = new FileWriter (getFilesDir () + "/benchmark_results.json" )) {
121- Gson gson = new Gson ();
122- writer .write (gson .toJson (results ));
123- } catch (IOException e ) {
124- e .printStackTrace ();
125- }
126- }
127- }.execute ();
105+ public BenchmarkHandler (Looper looper , BenchmarkActivity benchmarkActivity ) {
106+ super (looper );
107+ mModelRunner = new ModelRunner ();
108+ mBenchmarkActivity = benchmarkActivity ;
128109 }
129- }
130-
131- class Stats {
132- long loadStart ;
133- long loadEnd ;
134- List <Double > latency = new ArrayList <>();
135- int errorCode = 0 ;
136110
137111 @ Override
138- public String toString () {
139- return "latency: " + latency .stream ().map (Object ::toString ).collect (Collectors .joining ("" ));
112+ public void handleMessage (android .os .Message msg ) {
113+ if (msg .what == MESSAGE_RUN_BENCHMARK ) {
114+ mModelRunner .runBenchmark (
115+ mBenchmarkActivity .mModel ,
116+ mBenchmarkActivity .mNumWarmupIter ,
117+ mBenchmarkActivity .mNumIter ,
118+ mBenchmarkActivity .mResult );
119+
120+ if (mBenchmarkActivity .mTokenizerPath == null ) {
121+ mBenchmarkActivity .writeResult ();
122+ } else {
123+ this .sendEmptyMessage (MESSAGE_LLM_RUN_BENCHMARK );
124+ }
125+ } else if (msg .what == MESSAGE_LLM_RUN_BENCHMARK ) {
126+ mLlmBenchmark =
127+ new LlmBenchmark (
128+ mBenchmarkActivity ,
129+ mBenchmarkActivity .mModel .getPath (),
130+ mBenchmarkActivity .mTokenizerPath ,
131+ mBenchmarkActivity .mPrompt ,
132+ mBenchmarkActivity .mTemperature ,
133+ mBenchmarkActivity .mResult );
134+ }
140135 }
141136}
0 commit comments