Skip to content

Commit 7b53b88

Browse files
phaitingHaiting Pu
andauthored
Use lateinit var to remove !! and created a common TestFileUtils (#11155)
### Summary This change did two changes: * Use lateinit var to remove !! * Created a common TestFileUtils to share the same code for getTestFilePath across all instrumentation test code. ### Test plan ./gradlew :executorch_android:connectedAndroidTest Co-authored-by: Haiting Pu <[email protected]>
1 parent 7b1374c commit 7b53b88

File tree

4 files changed

+33
-30
lines changed

4 files changed

+33
-30
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ import org.apache.commons.io.FileUtils
1818
import org.json.JSONException
1919
import org.json.JSONObject
2020
import org.junit.Assert
21+
import org.junit.Assert.assertEquals
22+
import org.junit.Assert.assertThat
23+
import org.junit.Assert.assertTrue
2124
import org.junit.Before
2225
import org.junit.Rule
2326
import org.junit.Test
2427
import org.junit.runner.RunWith
28+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
2529
import org.pytorch.executorch.extension.llm.LlmCallback
2630
import org.pytorch.executorch.extension.llm.LlmModule
2731

@@ -30,7 +34,7 @@ import org.pytorch.executorch.extension.llm.LlmModule
3034
class LlmModuleInstrumentationTest : LlmCallback {
3135
private val results: MutableList<String> = ArrayList()
3236
private val tokensPerSecond: MutableList<Float> = ArrayList()
33-
private var llmModule: LlmModule? = null
37+
private lateinit var llmModule: LlmModule
3438

3539
@Before
3640
@Throws(IOException::class)
@@ -57,25 +61,25 @@ class LlmModuleInstrumentationTest : LlmCallback {
5761
@Test
5862
@Throws(IOException::class, URISyntaxException::class)
5963
fun testGenerate() {
60-
val loadResult = llmModule!!.load()
64+
val loadResult = llmModule.load()
6165
// Check that the model can be load successfully
62-
Assert.assertEquals(OK.toLong(), loadResult.toLong())
66+
assertEquals(OK.toLong(), loadResult.toLong())
6367

64-
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
65-
Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong())
66-
Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
68+
llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
69+
assertEquals(results.size.toLong(), SEQ_LEN.toLong())
70+
assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
6771
}
6872

6973
@Test
7074
@Throws(IOException::class, URISyntaxException::class)
7175
fun testGenerateAndStop() {
72-
llmModule!!.generate(
76+
llmModule.generate(
7377
TEST_PROMPT,
7478
SEQ_LEN,
7579
object : LlmCallback {
7680
override fun onResult(result: String) {
7781
this@LlmModuleInstrumentationTest.onResult(result)
78-
llmModule!!.stop()
82+
llmModule.stop()
7983
}
8084

8185
override fun onStats(stats: String) {
@@ -85,7 +89,7 @@ class LlmModuleInstrumentationTest : LlmCallback {
8589
)
8690

8791
val stoppedResultSize = results.size
88-
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
92+
assertTrue(stoppedResultSize < SEQ_LEN)
8993
}
9094

9195
override fun onResult(result: String) {
@@ -101,7 +105,8 @@ class LlmModuleInstrumentationTest : LlmCallback {
101105
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
102106
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103107
tokensPerSecond.add(tps)
104-
} catch (_: JSONException) {}
108+
} catch (_: JSONException) {
109+
}
105110
}
106111

107112
companion object {
@@ -110,12 +115,5 @@ class LlmModuleInstrumentationTest : LlmCallback {
110115
private const val TEST_PROMPT = "Hello"
111116
private const val OK = 0x00
112117
private const val SEQ_LEN = 32
113-
114-
private fun getTestFilePath(fileName: String): String {
115-
return InstrumentationRegistry.getInstrumentation()
116-
.targetContext
117-
.externalCacheDir
118-
.toString() + fileName
119-
}
120118
}
121119
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.junit.Rule
2222
import org.junit.Test
2323
import org.junit.runner.RunWith
2424
import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor
25+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
2526

2627
/** Unit tests for [Module]. */
2728
@RunWith(AndroidJUnit4::class)
@@ -90,12 +91,6 @@ class ModuleE2ETest {
9091
}
9192

9293
companion object {
93-
private fun getTestFilePath(fileName: String): String {
94-
return InstrumentationRegistry.getInstrumentation()
95-
.targetContext
96-
.externalCacheDir
97-
.toString() + fileName
98-
}
9994

10095
fun argmax(array: FloatArray): Int {
10196
require(array.isNotEmpty()) { "Array cannot be empty" }

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.junit.Before
2323
import org.junit.Rule
2424
import org.junit.Test
2525
import org.junit.runner.RunWith
26+
import org.pytorch.executorch.TestFileUtils.getTestFilePath
2627

2728
/** Unit tests for [Module]. */
2829
@RunWith(AndroidJUnit4::class)
@@ -173,12 +174,5 @@ class ModuleInstrumentationTest {
173174
private const val INVALID_STATE = 0x2
174175
private const val INVALID_ARGUMENT = 0x12
175176
private const val ACCESS_FAILED = 0x22
176-
177-
private fun getTestFilePath(fileName: String): String {
178-
return InstrumentationRegistry.getInstrumentation()
179-
.targetContext
180-
.externalCacheDir
181-
.toString() + fileName
182-
}
183177
}
184178
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.pytorch.executorch
2+
3+
import androidx.test.InstrumentationRegistry
4+
5+
/**
6+
* Test File Utils
7+
*/
8+
object TestFileUtils {
9+
10+
fun getTestFilePath(fileName: String): String {
11+
return InstrumentationRegistry.getInstrumentation()
12+
.targetContext
13+
.externalCacheDir
14+
.toString() + fileName
15+
}
16+
}

0 commit comments

Comments
 (0)