@@ -11,6 +11,9 @@ import android.Manifest
1111import androidx.test.InstrumentationRegistry
1212import androidx.test.ext.junit.runners.AndroidJUnit4
1313import androidx.test.rule.GrantPermissionRule
14+ import java.io.File
15+ import java.io.IOException
16+ import java.net.URISyntaxException
1417import org.apache.commons.io.FileUtils
1518import org.json.JSONException
1619import org.json.JSONObject
@@ -21,102 +24,98 @@ import org.junit.Test
2124import org.junit.runner.RunWith
2225import org.pytorch.executorch.extension.llm.LlmCallback
2326import org.pytorch.executorch.extension.llm.LlmModule
24- import java.io.File
25- import java.io.IOException
26- import java.net.URISyntaxException
2727
2828/* * Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
2929@RunWith(AndroidJUnit4 ::class )
3030class LlmModuleInstrumentationTest : LlmCallback {
31- private val results: MutableList <String > = ArrayList ()
32- private val tokensPerSecond: MutableList <Float > = ArrayList ()
33- private var llmModule: LlmModule ? = null
31+ private val results: MutableList <String > = ArrayList ()
32+ private val tokensPerSecond: MutableList <Float > = ArrayList ()
33+ private var llmModule: LlmModule ? = null
3434
35- @Before
36- @Throws(IOException ::class )
37- fun setUp () {
38- // copy zipped test resources to local device
39- val addPteFile = File (getTestFilePath(TEST_FILE_NAME ))
40- var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME )
41- FileUtils .copyInputStreamToFile(inputStream, addPteFile)
42- inputStream.close()
35+ @Before
36+ @Throws(IOException ::class )
37+ fun setUp () {
38+ // copy zipped test resources to local device
39+ val addPteFile = File (getTestFilePath(TEST_FILE_NAME ))
40+ var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME )
41+ FileUtils .copyInputStreamToFile(inputStream, addPteFile)
42+ inputStream.close()
4343
44- val tokenizerFile = File (getTestFilePath(TOKENIZER_FILE_NAME ))
45- inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME )
46- FileUtils .copyInputStreamToFile(inputStream, tokenizerFile)
47- inputStream.close()
44+ val tokenizerFile = File (getTestFilePath(TOKENIZER_FILE_NAME ))
45+ inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME )
46+ FileUtils .copyInputStreamToFile(inputStream, tokenizerFile)
47+ inputStream.close()
4848
49- llmModule =
50- LlmModule (getTestFilePath(TEST_FILE_NAME ), getTestFilePath(TOKENIZER_FILE_NAME ), 0.0f )
51- }
49+ llmModule =
50+ LlmModule (getTestFilePath(TEST_FILE_NAME ), getTestFilePath(TOKENIZER_FILE_NAME ), 0.0f )
51+ }
5252
53- @get:Rule
54- var runtimePermissionRule: GrantPermissionRule =
55- GrantPermissionRule .grant(Manifest .permission.READ_EXTERNAL_STORAGE )
53+ @get:Rule
54+ var runtimePermissionRule: GrantPermissionRule =
55+ GrantPermissionRule .grant(Manifest .permission.READ_EXTERNAL_STORAGE )
5656
57- @Test
58- @Throws(IOException ::class , URISyntaxException ::class )
59- fun testGenerate () {
60- val loadResult = llmModule!! .load()
61- // Check that the model can be load successfully
62- Assert .assertEquals(OK .toLong(), loadResult.toLong())
57+ @Test
58+ @Throws(IOException ::class , URISyntaxException ::class )
59+ fun testGenerate () {
60+ val loadResult = llmModule!! .load()
61+ // Check that the model can be load successfully
62+ Assert .assertEquals(OK .toLong(), loadResult.toLong())
6363
64- llmModule!! .generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
65- Assert .assertEquals(results.size.toLong(), SEQ_LEN .toLong())
66- Assert .assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
67- }
64+ llmModule!! .generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
65+ Assert .assertEquals(results.size.toLong(), SEQ_LEN .toLong())
66+ Assert .assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
67+ }
6868
69- @Test
70- @Throws(IOException ::class , URISyntaxException ::class )
71- fun testGenerateAndStop () {
72- llmModule!! .generate(
73- TEST_PROMPT ,
74- SEQ_LEN ,
75- object : LlmCallback {
76- override fun onResult (result : String ) {
77- this @LlmModuleInstrumentationTest.onResult(result)
78- llmModule!! .stop()
79- }
69+ @Test
70+ @Throws(IOException ::class , URISyntaxException ::class )
71+ fun testGenerateAndStop () {
72+ llmModule!! .generate(
73+ TEST_PROMPT ,
74+ SEQ_LEN ,
75+ object : LlmCallback {
76+ override fun onResult (result : String ) {
77+ this @LlmModuleInstrumentationTest.onResult(result)
78+ llmModule!! .stop()
79+ }
8080
81- override fun onStats (stats : String ) {
82- this @LlmModuleInstrumentationTest.onStats(stats)
83- }
84- },
85- )
81+ override fun onStats (stats : String ) {
82+ this @LlmModuleInstrumentationTest.onStats(stats)
83+ }
84+ },
85+ )
8686
87- val stoppedResultSize = results.size
88- Assert .assertTrue(stoppedResultSize < SEQ_LEN )
89- }
87+ val stoppedResultSize = results.size
88+ Assert .assertTrue(stoppedResultSize < SEQ_LEN )
89+ }
9090
91- override fun onResult (result : String ) {
92- results.add(result)
93- }
91+ override fun onResult (result : String ) {
92+ results.add(result)
93+ }
9494
95- override fun onStats (stats : String ) {
96- var tps = 0f
97- try {
98- val jsonObject = JSONObject (stats)
99- val numGeneratedTokens = jsonObject.getInt(" generated_tokens" )
100- val inferenceEndMs = jsonObject.getInt(" inference_end_ms" )
101- val promptEvalEndMs = jsonObject.getInt(" prompt_eval_end_ms" )
102- tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103- tokensPerSecond.add(tps)
104- } catch (_: JSONException ) {
105- }
106- }
95+ override fun onStats (stats : String ) {
96+ var tps = 0f
97+ try {
98+ val jsonObject = JSONObject (stats)
99+ val numGeneratedTokens = jsonObject.getInt(" generated_tokens" )
100+ val inferenceEndMs = jsonObject.getInt(" inference_end_ms" )
101+ val promptEvalEndMs = jsonObject.getInt(" prompt_eval_end_ms" )
102+ tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103+ tokensPerSecond.add(tps)
104+ } catch (_: JSONException ) {}
105+ }
107106
108- companion object {
109- private const val TEST_FILE_NAME = " /stories.pte"
110- private const val TOKENIZER_FILE_NAME = " /tokenizer.bin"
111- private const val TEST_PROMPT = " Hello"
112- private const val OK = 0x00
113- private const val SEQ_LEN = 32
107+ companion object {
108+ private const val TEST_FILE_NAME = " /stories.pte"
109+ private const val TOKENIZER_FILE_NAME = " /tokenizer.bin"
110+ private const val TEST_PROMPT = " Hello"
111+ private const val OK = 0x00
112+ private const val SEQ_LEN = 32
114113
115- private fun getTestFilePath (fileName : String ): String {
116- return InstrumentationRegistry .getInstrumentation()
117- .targetContext
118- .externalCacheDir
119- .toString() + fileName
120- }
114+ private fun getTestFilePath (fileName : String ): String {
115+ return InstrumentationRegistry .getInstrumentation()
116+ .targetContext
117+ .externalCacheDir
118+ .toString() + fileName
121119 }
120+ }
122121}
0 commit comments