88
99package com .example .executorchllamademo ;
1010
11- import static junit .framework .TestCase .assertTrue ;
1211import static org .junit .Assert .assertEquals ;
1312import static org .junit .Assert .assertFalse ;
1413
14+ import android .os .Bundle ;
1515import androidx .test .ext .junit .runners .AndroidJUnit4 ;
16+ import androidx .test .platform .app .InstrumentationRegistry ;
17+ import java .io .File ;
1618import java .util .ArrayList ;
19+ import java .util .Arrays ;
1720import java .util .List ;
1821import org .junit .Test ;
1922import org .junit .runner .RunWith ;
2427public class PerfTest implements LlamaCallback {
2528
  // On-device directory where the benchmark setup pushes the model and tokenizer files.
  private static final String RESOURCE_PATH = "/data/local/tmp/llama/";
  private static final String TOKENIZER_BIN = "tokenizer.bin";

  // Generated text pieces streamed back through the LlamaCallback; presumably
  // appended by onResult (not visible in this chunk) -- TODO confirm.
  private final List<String> results = new ArrayList<>();
  // Tokens-per-second samples appended by onStats() after each generation run.
  private final List<Float> tokensPerSecond = new ArrayList<>();
3534
3635 @ Test
3736 public void testTokensPerSecond () {
38- String modelPath = RESOURCE_PATH + MODEL_NAME ;
3937 String tokenizerPath = RESOURCE_PATH + TOKENIZER_BIN ;
40- LlamaModule mModule = new LlamaModule (modelPath , tokenizerPath , 0.8f );
38+ // Find out the model name
39+ File directory = new File (RESOURCE_PATH );
40+ Arrays .stream (directory .listFiles ())
41+ .filter (file -> file .getName ().endsWith (".pte" ))
42+ .forEach (
43+ model -> {
44+ LlamaModule mModule = new LlamaModule (model .getPath (), tokenizerPath , 0.8f );
45+ // Print the model name because there might be more than one of them
46+ report ("ModelName" , model .getName ());
4147
42- int loadResult = mModule .load ();
43- // Check that the model can be load successfully
44- assertEquals (0 , loadResult );
48+ int loadResult = mModule .load ();
49+ // Check that the model can be load successfully
50+ assertEquals (0 , loadResult );
4551
46- // Run a testing prompt
47- mModule .generate ("How do you do! I'm testing llama2 on mobile device" , PerfTest .this );
48- assertFalse (tokensPerSecond .isEmpty ());
52+ // Run a testing prompt
53+ mModule .generate ("How do you do! I'm testing llama2 on mobile device" , PerfTest .this );
54+ assertFalse (tokensPerSecond .isEmpty ());
4955
50- final Float tps = tokensPerSecond .get (tokensPerSecond .size () - 1 );
51- assertTrue (
52- "The observed TPS " + tps + " is less than the expected TPS " + EXPECTED_TPS ,
53- tps >= EXPECTED_TPS );
56+ final Float tps = tokensPerSecond .get (tokensPerSecond .size () - 1 );
57+ report ("TPS" , tps );
58+ });
5459 }
5560
5661 @ Override
@@ -62,4 +67,16 @@ public void onResult(String result) {
  // LlamaCallback hook: records each tokens-per-second sample so the test body
  // can read the latest value after generate() completes.
  public void onStats(float tps) {
    tokensPerSecond.add(tps);
  }
70+
71+ private void report (final String metric , final Float value ) {
72+ Bundle bundle = new Bundle ();
73+ bundle .putFloat (metric , value );
74+ InstrumentationRegistry .getInstrumentation ().sendStatus (0 , bundle );
75+ }
76+
77+ private void report (final String key , final String value ) {
78+ Bundle bundle = new Bundle ();
79+ bundle .putString (key , value );
80+ InstrumentationRegistry .getInstrumentation ().sendStatus (0 , bundle );
81+ }
6582}
0 commit comments