Skip to content

Commit 6660276

Browse files
committed
fmt
1 parent 0dc2ba5 commit 6660276

File tree

4 files changed

+437
-442
lines changed

4 files changed

+437
-442
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt

Lines changed: 79 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import android.Manifest
1111
import androidx.test.InstrumentationRegistry
1212
import androidx.test.ext.junit.runners.AndroidJUnit4
1313
import androidx.test.rule.GrantPermissionRule
14+
import java.io.File
15+
import java.io.IOException
16+
import java.net.URISyntaxException
1417
import org.apache.commons.io.FileUtils
1518
import org.json.JSONException
1619
import org.json.JSONObject
@@ -21,102 +24,98 @@ import org.junit.Test
2124
import org.junit.runner.RunWith
2225
import org.pytorch.executorch.extension.llm.LlmCallback
2326
import org.pytorch.executorch.extension.llm.LlmModule
24-
import java.io.File
25-
import java.io.IOException
26-
import java.net.URISyntaxException
2727

2828
/** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */
2929
@RunWith(AndroidJUnit4::class)
3030
class LlmModuleInstrumentationTest : LlmCallback {
31-
private val results: MutableList<String> = ArrayList()
32-
private val tokensPerSecond: MutableList<Float> = ArrayList()
33-
private var llmModule: LlmModule? = null
31+
private val results: MutableList<String> = ArrayList()
32+
private val tokensPerSecond: MutableList<Float> = ArrayList()
33+
private var llmModule: LlmModule? = null
3434

35-
@Before
36-
@Throws(IOException::class)
37-
fun setUp() {
38-
// copy zipped test resources to local device
39-
val addPteFile = File(getTestFilePath(TEST_FILE_NAME))
40-
var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME)
41-
FileUtils.copyInputStreamToFile(inputStream, addPteFile)
42-
inputStream.close()
35+
@Before
36+
@Throws(IOException::class)
37+
fun setUp() {
38+
// copy zipped test resources to local device
39+
val addPteFile = File(getTestFilePath(TEST_FILE_NAME))
40+
var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME)
41+
FileUtils.copyInputStreamToFile(inputStream, addPteFile)
42+
inputStream.close()
4343

44-
val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME))
45-
inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)
46-
FileUtils.copyInputStreamToFile(inputStream, tokenizerFile)
47-
inputStream.close()
44+
val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME))
45+
inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME)
46+
FileUtils.copyInputStreamToFile(inputStream, tokenizerFile)
47+
inputStream.close()
4848

49-
llmModule =
50-
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
51-
}
49+
llmModule =
50+
LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f)
51+
}
5252

53-
@get:Rule
54-
var runtimePermissionRule: GrantPermissionRule =
55-
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
53+
@get:Rule
54+
var runtimePermissionRule: GrantPermissionRule =
55+
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
5656

57-
@Test
58-
@Throws(IOException::class, URISyntaxException::class)
59-
fun testGenerate() {
60-
val loadResult = llmModule!!.load()
61-
// Check that the model can be load successfully
62-
Assert.assertEquals(OK.toLong(), loadResult.toLong())
57+
@Test
58+
@Throws(IOException::class, URISyntaxException::class)
59+
fun testGenerate() {
60+
val loadResult = llmModule!!.load()
61+
// Check that the model can be load successfully
62+
Assert.assertEquals(OK.toLong(), loadResult.toLong())
6363

64-
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
65-
Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong())
66-
Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
67-
}
64+
llmModule!!.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest)
65+
Assert.assertEquals(results.size.toLong(), SEQ_LEN.toLong())
66+
Assert.assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0)
67+
}
6868

69-
@Test
70-
@Throws(IOException::class, URISyntaxException::class)
71-
fun testGenerateAndStop() {
72-
llmModule!!.generate(
73-
TEST_PROMPT,
74-
SEQ_LEN,
75-
object : LlmCallback {
76-
override fun onResult(result: String) {
77-
this@LlmModuleInstrumentationTest.onResult(result)
78-
llmModule!!.stop()
79-
}
69+
@Test
70+
@Throws(IOException::class, URISyntaxException::class)
71+
fun testGenerateAndStop() {
72+
llmModule!!.generate(
73+
TEST_PROMPT,
74+
SEQ_LEN,
75+
object : LlmCallback {
76+
override fun onResult(result: String) {
77+
this@LlmModuleInstrumentationTest.onResult(result)
78+
llmModule!!.stop()
79+
}
8080

81-
override fun onStats(stats: String) {
82-
this@LlmModuleInstrumentationTest.onStats(stats)
83-
}
84-
},
85-
)
81+
override fun onStats(stats: String) {
82+
this@LlmModuleInstrumentationTest.onStats(stats)
83+
}
84+
},
85+
)
8686

87-
val stoppedResultSize = results.size
88-
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
89-
}
87+
val stoppedResultSize = results.size
88+
Assert.assertTrue(stoppedResultSize < SEQ_LEN)
89+
}
9090

91-
override fun onResult(result: String) {
92-
results.add(result)
93-
}
91+
override fun onResult(result: String) {
92+
results.add(result)
93+
}
9494

95-
override fun onStats(stats: String) {
96-
var tps = 0f
97-
try {
98-
val jsonObject = JSONObject(stats)
99-
val numGeneratedTokens = jsonObject.getInt("generated_tokens")
100-
val inferenceEndMs = jsonObject.getInt("inference_end_ms")
101-
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
102-
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103-
tokensPerSecond.add(tps)
104-
} catch (_: JSONException) {
105-
}
106-
}
95+
override fun onStats(stats: String) {
96+
var tps = 0f
97+
try {
98+
val jsonObject = JSONObject(stats)
99+
val numGeneratedTokens = jsonObject.getInt("generated_tokens")
100+
val inferenceEndMs = jsonObject.getInt("inference_end_ms")
101+
val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms")
102+
tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000
103+
tokensPerSecond.add(tps)
104+
} catch (_: JSONException) {}
105+
}
107106

108-
companion object {
109-
private const val TEST_FILE_NAME = "/stories.pte"
110-
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
111-
private const val TEST_PROMPT = "Hello"
112-
private const val OK = 0x00
113-
private const val SEQ_LEN = 32
107+
companion object {
108+
private const val TEST_FILE_NAME = "/stories.pte"
109+
private const val TOKENIZER_FILE_NAME = "/tokenizer.bin"
110+
private const val TEST_PROMPT = "Hello"
111+
private const val OK = 0x00
112+
private const val SEQ_LEN = 32
114113

115-
private fun getTestFilePath(fileName: String): String {
116-
return InstrumentationRegistry.getInstrumentation()
117-
.targetContext
118-
.externalCacheDir
119-
.toString() + fileName
120-
}
114+
private fun getTestFilePath(fileName: String): String {
115+
return InstrumentationRegistry.getInstrumentation()
116+
.targetContext
117+
.externalCacheDir
118+
.toString() + fileName
121119
}
120+
}
122121
}

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt

Lines changed: 75 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -13,104 +13,104 @@ import android.graphics.BitmapFactory
1313
import androidx.test.InstrumentationRegistry
1414
import androidx.test.ext.junit.runners.AndroidJUnit4
1515
import androidx.test.rule.GrantPermissionRule
16+
import java.io.File
17+
import java.io.IOException
18+
import java.net.URISyntaxException
1619
import org.apache.commons.io.FileUtils
1720
import org.junit.Assert
1821
import org.junit.Rule
1922
import org.junit.Test
2023
import org.junit.runner.RunWith
2124
import org.pytorch.executorch.TensorImageUtils.bitmapToFloat32Tensor
22-
import java.io.File
23-
import java.io.IOException
24-
import java.net.URISyntaxException
2525

2626
/** Unit tests for [Module]. */
2727
@RunWith(AndroidJUnit4::class)
2828
class ModuleE2ETest {
29-
@get:Rule
30-
var runtimePermissionRule: GrantPermissionRule =
31-
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
29+
@get:Rule
30+
var runtimePermissionRule: GrantPermissionRule =
31+
GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE)
3232

33-
@Throws(IOException::class, URISyntaxException::class)
34-
fun testClassification(filePath: String) {
35-
val pteFile = File(getTestFilePath(filePath))
36-
val inputStream = javaClass.getResourceAsStream(filePath)
37-
FileUtils.copyInputStreamToFile(inputStream, pteFile)
38-
inputStream.close()
33+
@Throws(IOException::class, URISyntaxException::class)
34+
fun testClassification(filePath: String) {
35+
val pteFile = File(getTestFilePath(filePath))
36+
val inputStream = javaClass.getResourceAsStream(filePath)
37+
FileUtils.copyInputStreamToFile(inputStream, pteFile)
38+
inputStream.close()
3939

40-
val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg")
41-
var bitmap = BitmapFactory.decodeStream(imgInputStream)
42-
bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true)
43-
imgInputStream.close()
40+
val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg")
41+
var bitmap = BitmapFactory.decodeStream(imgInputStream)
42+
bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true)
43+
imgInputStream.close()
4444

45-
val inputTensor =
46-
bitmapToFloat32Tensor(
47-
bitmap,
48-
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
49-
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
50-
)
45+
val inputTensor =
46+
bitmapToFloat32Tensor(
47+
bitmap,
48+
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
49+
TensorImageUtils.TORCHVISION_NORM_STD_RGB,
50+
)
5151

52-
val module = Module.load(getTestFilePath(filePath))
52+
val module = Module.load(getTestFilePath(filePath))
5353

54-
val results = module.forward(EValue.from(inputTensor))
55-
Assert.assertTrue(results[0].isTensor)
56-
val scores = results[0].toTensor().dataAsFloatArray
54+
val results = module.forward(EValue.from(inputTensor))
55+
Assert.assertTrue(results[0].isTensor)
56+
val scores = results[0].toTensor().dataAsFloatArray
5757

58-
val bananaClass = 954 // From ImageNet 1K
59-
Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong())
60-
}
58+
val bananaClass = 954 // From ImageNet 1K
59+
Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong())
60+
}
6161

62-
@Test
63-
@Throws(IOException::class, URISyntaxException::class)
64-
fun testXnnpackBackendRequired() {
65-
val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte"))
66-
val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte")
67-
FileUtils.copyInputStreamToFile(inputStream, pteFile)
68-
inputStream.close()
62+
@Test
63+
@Throws(IOException::class, URISyntaxException::class)
64+
fun testXnnpackBackendRequired() {
65+
val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte"))
66+
val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte")
67+
FileUtils.copyInputStreamToFile(inputStream, pteFile)
68+
inputStream.close()
6969

70-
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
71-
Assert.assertArrayEquals(
72-
arrayOf("XnnpackBackend"),
73-
module.getMethodMetadata("forward").getBackends(),
74-
)
75-
}
70+
val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
71+
Assert.assertArrayEquals(
72+
arrayOf("XnnpackBackend"),
73+
module.getMethodMetadata("forward").getBackends(),
74+
)
75+
}
7676

77-
@Test
78-
@Throws(IOException::class, URISyntaxException::class)
79-
fun testMv2Fp32() {
80-
testClassification("/mv2_xnnpack_fp32.pte")
81-
}
77+
@Test
78+
@Throws(IOException::class, URISyntaxException::class)
79+
fun testMv2Fp32() {
80+
testClassification("/mv2_xnnpack_fp32.pte")
81+
}
8282

83-
@Test
84-
@Throws(IOException::class, URISyntaxException::class)
85-
fun testMv3Fp32() {
86-
testClassification("/mv3_xnnpack_fp32.pte")
87-
}
83+
@Test
84+
@Throws(IOException::class, URISyntaxException::class)
85+
fun testMv3Fp32() {
86+
testClassification("/mv3_xnnpack_fp32.pte")
87+
}
8888

89-
@Test
90-
@Throws(IOException::class, URISyntaxException::class)
91-
fun testResnet50() {
92-
testClassification("/resnet50_xnnpack_q8.pte")
93-
}
89+
@Test
90+
@Throws(IOException::class, URISyntaxException::class)
91+
fun testResnet50() {
92+
testClassification("/resnet50_xnnpack_q8.pte")
93+
}
9494

95-
companion object {
96-
private fun getTestFilePath(fileName: String): String {
97-
return InstrumentationRegistry.getInstrumentation()
98-
.targetContext
99-
.externalCacheDir
100-
.toString() + fileName
101-
}
95+
companion object {
96+
private fun getTestFilePath(fileName: String): String {
97+
return InstrumentationRegistry.getInstrumentation()
98+
.targetContext
99+
.externalCacheDir
100+
.toString() + fileName
101+
}
102102

103-
fun argmax(array: FloatArray): Int {
104-
require(array.isNotEmpty()) { "Array cannot be empty" }
105-
var maxIndex = 0
106-
var maxValue = array[0]
107-
for (i in 1 until array.size) {
108-
if (array[i] > maxValue) {
109-
maxValue = array[i]
110-
maxIndex = i
111-
}
112-
}
113-
return maxIndex
103+
fun argmax(array: FloatArray): Int {
104+
require(array.isNotEmpty()) { "Array cannot be empty" }
105+
var maxIndex = 0
106+
var maxValue = array[0]
107+
for (i in 1 until array.size) {
108+
if (array[i] > maxValue) {
109+
maxValue = array[i]
110+
maxIndex = i
114111
}
112+
}
113+
return maxIndex
115114
}
115+
}
116116
}

0 commit comments

Comments
 (0)