@@ -18,10 +18,14 @@ import org.apache.commons.io.FileUtils
1818import org.json.JSONException
1919import org.json.JSONObject
2020import org.junit.Assert
21+ import org.junit.Assert.assertEquals
22+ import org.junit.Assert.assertThat
23+ import org.junit.Assert.assertTrue
2124import org.junit.Before
2225import org.junit.Rule
2326import org.junit.Test
2427import org.junit.runner.RunWith
28+ import org.pytorch.executorch.TestFileUtils.getTestFilePath
2529import org.pytorch.executorch.extension.llm.LlmCallback
2630import org.pytorch.executorch.extension.llm.LlmModule
2731
@@ -30,7 +34,7 @@ import org.pytorch.executorch.extension.llm.LlmModule
3034class LlmModuleInstrumentationTest : LlmCallback {
3135 private val results: MutableList <String > = ArrayList ()
3236 private val tokensPerSecond: MutableList <Float > = ArrayList ()
33- private var llmModule: LlmModule ? = null
37+ private lateinit var llmModule: LlmModule
3438
3539 @Before
3640 @Throws(IOException ::class )
@@ -57,25 +61,25 @@ class LlmModuleInstrumentationTest : LlmCallback {
5761 @Test
5862 @Throws(IOException ::class , URISyntaxException ::class )
5963 fun testGenerate () {
60- val loadResult = llmModule!! .load()
64+ val loadResult = llmModule.load()
6165 // Check that the model can be load successfully
62- Assert . assertEquals(OK .toLong(), loadResult.toLong())
66+ assertEquals(OK .toLong(), loadResult.toLong())
6367
64- llmModule!! .generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
65- Assert . assertEquals(results.size.toLong(), SEQ_LEN .toLong())
66- Assert . assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
68+ llmModule.generate(TEST_PROMPT , SEQ_LEN , this @LlmModuleInstrumentationTest)
69+ assertEquals(results.size.toLong(), SEQ_LEN .toLong())
70+ assertTrue(tokensPerSecond[tokensPerSecond.size - 1 ] > 0 )
6771 }
6872
6973 @Test
7074 @Throws(IOException ::class , URISyntaxException ::class )
7175 fun testGenerateAndStop () {
72- llmModule!! .generate(
76+ llmModule.generate(
7377 TEST_PROMPT ,
7478 SEQ_LEN ,
7579 object : LlmCallback {
7680 override fun onResult (result : String ) {
7781 this @LlmModuleInstrumentationTest.onResult(result)
78- llmModule!! .stop()
82+ llmModule.stop()
7983 }
8084
8185 override fun onStats (stats : String ) {
@@ -85,7 +89,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
8589 )
8690
8791 val stoppedResultSize = results.size
88- Assert . assertTrue(stoppedResultSize < SEQ_LEN )
92+ assertTrue(stoppedResultSize < SEQ_LEN )
8993 }
9094
9195 override fun onResult (result : String ) {
@@ -101,7 +105,8 @@ class LlmModuleInstrumentationTest : LlmCallback {
101105 val promptEvalEndMs = jsonObject.getInt(" prompt_eval_end_ms" )
102106 tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103107 tokensPerSecond.add(tps)
104- } catch (_: JSONException ) {}
108+ } catch (_: JSONException ) {
109+ }
105110 }
106111
107112 companion object {
@@ -110,12 +115,5 @@ class LlmModuleInstrumentationTest : LlmCallback {
110115 private const val TEST_PROMPT = " Hello"
111116 private const val OK = 0x00
112117 private const val SEQ_LEN = 32
113-
114- private fun getTestFilePath (fileName : String ): String {
115- return InstrumentationRegistry .getInstrumentation()
116- .targetContext
117- .externalCacheDir
118- .toString() + fileName
119- }
120118 }
121119}
0 commit comments