88package org.pytorch.executorch
99
1010import android.Manifest
11- import androidx.test.InstrumentationRegistry
1211import androidx.test.ext.junit.runners.AndroidJUnit4
1312import androidx.test.rule.GrantPermissionRule
1413import java.io.File
@@ -17,9 +16,7 @@ import java.net.URISyntaxException
1716import org.apache.commons.io.FileUtils
1817import org.json.JSONException
1918import org.json.JSONObject
20- import org.junit.Assert
2119import org.junit.Assert.assertEquals
22- import org.junit.Assert.assertThat
2320import org.junit.Assert.assertTrue
2421import org.junit.Before
2522import org.junit.Rule
@@ -32,88 +29,87 @@ import org.pytorch.executorch.extension.llm.LlmModule
3229/* * Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
3330@RunWith(AndroidJUnit4 ::class )
3431class LlmModuleInstrumentationTest : LlmCallback {
35- private val results: MutableList <String > = ArrayList ()
36- private val tokensPerSecond: MutableList <Float > = ArrayList ()
37- private lateinit var llmModule: LlmModule
32+ private val results: MutableList <String > = ArrayList ()
33+ private val tokensPerSecond: MutableList <Float > = ArrayList ()
34+ private lateinit var llmModule: LlmModule
3835
39- @Before
40- @Throws(IOException ::class )
41- fun setUp () {
42- // copy zipped test resources to local device
43- val addPteFile = File (getTestFilePath(TEST_FILE_NAME ))
44- var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME )
45- FileUtils .copyInputStreamToFile(inputStream, addPteFile)
46- inputStream.close()
36+ @Before
37+ @Throws(IOException ::class )
38+ fun setUp () {
39+ // copy zipped test resources to local device
40+ val addPteFile = File (getTestFilePath(TEST_FILE_NAME ))
41+ var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME )
42+ FileUtils .copyInputStreamToFile(inputStream, addPteFile)
43+ inputStream.close()
4744
48- val tokenizerFile = File (getTestFilePath(TOKENIZER_FILE_NAME ))
49- inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME )
50- FileUtils .copyInputStreamToFile(inputStream, tokenizerFile)
51- inputStream.close()
45+ val tokenizerFile = File (getTestFilePath(TOKENIZER_FILE_NAME ))
46+ inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME )
47+ FileUtils .copyInputStreamToFile(inputStream, tokenizerFile)
48+ inputStream.close()
5249
53- llmModule =
54- LlmModule (getTestFilePath(TEST_FILE_NAME ), getTestFilePath(TOKENIZER_FILE_NAME ), 0.0f )
55- }
50+ llmModule =
51+ LlmModule (getTestFilePath(TEST_FILE_NAME ), getTestFilePath(TOKENIZER_FILE_NAME ), 0.0f )
52+ }
5653
57- @get:Rule
58- var runtimePermissionRule: GrantPermissionRule =
59- GrantPermissionRule .grant(Manifest .permission.READ_EXTERNAL_STORAGE )
54+ @get:Rule
55+ var runtimePermissionRule: GrantPermissionRule =
56+ GrantPermissionRule .grant(Manifest .permission.READ_EXTERNAL_STORAGE )
6057
61- @Test
62- @Throws(IOException ::class , URISyntaxException ::class )
63- fun testGenerate () {
64- val loadResult = llmModule.load()
65- // Check that the model can be load successfully
66- assertEquals(OK .toLong(), loadResult.toLong())
58+ @Test
59+ @Throws(IOException ::class , URISyntaxException ::class )
60+ fun testGenerate () {
61+ val loadResult = llmModule.load()
62+ // Check that the model can be load successfully
63+ assertEquals(OK .toLong(), loadResult.toLong())
6764
68- llmModule.generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
69- assertEquals(results.size.toLong(), SEQ_LEN .toLong())
70- assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
71- }
65+ llmModule.generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
66+ assertEquals(results.size.toLong(), SEQ_LEN .toLong())
67+ assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
68+ }
7269
73- @Test
74- @Throws(IOException ::class , URISyntaxException ::class )
75- fun testGenerateAndStop () {
76- llmModule.generate(
77- TEST_PROMPT ,
78- SEQ_LEN ,
79- object : LlmCallback {
80- override fun onResult (result : String ) {
81- this @LlmModuleInstrumentationTest.onResult(result)
82- llmModule.stop()
83- }
70+ @Test
71+ @Throws(IOException ::class , URISyntaxException ::class )
72+ fun testGenerateAndStop () {
73+ llmModule.generate(
74+ TEST_PROMPT ,
75+ SEQ_LEN ,
76+ object : LlmCallback {
77+ override fun onResult (result : String ) {
78+ this @LlmModuleInstrumentationTest.onResult(result)
79+ llmModule.stop()
80+ }
8481
85- override fun onStats (stats : String ) {
86- this @LlmModuleInstrumentationTest.onStats(stats)
87- }
88- },
89- )
82+ override fun onStats (stats : String ) {
83+ this @LlmModuleInstrumentationTest.onStats(stats)
84+ }
85+ },
86+ )
9087
91- val stoppedResultSize = results.size
92- assertTrue(stoppedResultSize < SEQ_LEN )
93- }
88+ val stoppedResultSize = results.size
89+ assertTrue(stoppedResultSize < SEQ_LEN )
90+ }
9491
95- override fun onResult (result : String ) {
96- results.add(result)
97- }
92+ override fun onResult (result : String ) {
93+ results.add(result)
94+ }
9895
99- override fun onStats (stats : String ) {
100- var tps = 0f
101- try {
102- val jsonObject = JSONObject (stats)
103- val numGeneratedTokens = jsonObject.getInt(" generated_tokens" )
104- val inferenceEndMs = jsonObject.getInt(" inference_end_ms" )
105- val promptEvalEndMs = jsonObject.getInt(" prompt_eval_end_ms" )
106- tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
107- tokensPerSecond.add(tps)
108- } catch (_: JSONException ) {
109- }
110- }
96+ override fun onStats (stats : String ) {
97+ var tps = 0f
98+ try {
99+ val jsonObject = JSONObject (stats)
100+ val numGeneratedTokens = jsonObject.getInt(" generated_tokens" )
101+ val inferenceEndMs = jsonObject.getInt(" inference_end_ms" )
102+ val promptEvalEndMs = jsonObject.getInt(" prompt_eval_end_ms" )
103+ tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
104+ tokensPerSecond.add(tps)
105+ } catch (_: JSONException ) {}
106+ }
111107
112- companion object {
113- private const val TEST_FILE_NAME = " /stories.pte"
114- private const val TOKENIZER_FILE_NAME = " /tokenizer.bin"
115- private const val TEST_PROMPT = " Hello"
116- private const val OK = 0x00
117- private const val SEQ_LEN = 32
118- }
108+ companion object {
109+ private const val TEST_FILE_NAME = " /stories.pte"
110+ private const val TOKENIZER_FILE_NAME = " /tokenizer.bin"
111+ private const val TEST_PROMPT = " Hello"
112+ private const val OK = 0x00
113+ private const val SEQ_LEN = 32
114+ }
119115}
0 commit comments