diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index 0c18d60721e..bd8e7ed30e3 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -10,9 +10,17 @@ plugins { id "com.android.library" version "8.9.0" id "com.vanniktech.maven.publish" version "0.31.0" + id 'com.diffplug.spotless' version '8.0.0' alias(libs.plugins.jetbrains.kotlin.android) } +spotless { + kotlin { + target '**/*.kt' + ktfmt() + } +} + def qnnVersion = System.properties['qnnVersion'] def execuTorchVersion = System.properties['execuTorchVersion'] def flavor = System.properties['flavor'] @@ -37,7 +45,7 @@ android { jniLibs.srcDirs = ['../../../cmake-out-android-so/'] } androidTest { - resources.srcDirs += [ 'src/androidTest/resources' ] + resources.srcDirs += ['src/androidTest/resources'] } } kotlinOptions { @@ -46,7 +54,7 @@ android { } task copyTestRes(type: Exec) { - commandLine 'bash', 'android_test_setup.sh' + commandLine 'bash', 'android_test_setup.sh' } dependencies { @@ -67,36 +75,36 @@ dependencies { } mavenPublishing { - publishToMavenCentral() - signAllPublications() - - coordinates("org.pytorch", "executorch-android" + (flavor ? "-" + flavor : ""), execuTorchVersion ? execuTorchVersion : "1.0.0-SNAPSHOT") - - pom { - name = "ExecuTorch Android" - description = "ExecuTorch Android API" - inceptionYear = "2025" - url = "https://github.com/pytorch/executorch/" - licenses { - license { - name = "BSD 3-Clause" - url = "https://github.com/pytorch/executorch/blob/main/LICENSE" - distribution = "https://github.com/pytorch/executorch/blob/main/LICENSE" - } - } - developers { - developer { - id = "pytorch" - name = "pytorch" + publishToMavenCentral() + signAllPublications() + + coordinates("org.pytorch", "executorch-android" + (flavor ? "-" + flavor : ""), execuTorchVersion ? execuTorchVersion : "1.0.0-SNAPSHOT") + + pom { + name = "ExecuTorch Android" + description = "ExecuTorch Android API" + inceptionYear = "2025" url = "https://github.com/pytorch/executorch/" - } - } - scm { - url = "https://github.com/pytorch/executorch.git" - connection = "scm:git:https://github.com/pytorch/executorch" - developerConnection = "scm:git:git@github.com:pytorch/executorch.git" + licenses { + license { + name = "BSD 3-Clause" + url = "https://github.com/pytorch/executorch/blob/main/LICENSE" + distribution = "https://github.com/pytorch/executorch/blob/main/LICENSE" + } + } + developers { + developer { + id = "pytorch" + name = "pytorch" + url = "https://github.com/pytorch/executorch/" + } + } + scm { + url = "https://github.com/pytorch/executorch.git" + connection = "scm:git:https://github.com/pytorch/executorch" + developerConnection = "scm:git:git@github.com:pytorch/executorch.git" + } } - } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt index 2df45f14985..b6b314de447 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.kt @@ -8,7 +8,6 @@ package org.pytorch.executorch import android.Manifest -import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule import java.io.File @@ -17,9 +16,7 @@ import java.net.URISyntaxException import org.apache.commons.io.FileUtils import org.json.JSONException import org.json.JSONObject -import org.junit.Assert import org.junit.Assert.assertEquals -import org.junit.Assert.assertThat import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Rule @@ -32,88 +29,87 @@ import org.pytorch.executorch.extension.llm.LlmModule /** Unit tests for [org.pytorch.executorch.extension.llm.LlmModule]. */ @RunWith(AndroidJUnit4::class) class LlmModuleInstrumentationTest : LlmCallback { - private val results: MutableList = ArrayList() - private val tokensPerSecond: MutableList = ArrayList() - private lateinit var llmModule: LlmModule + private val results: MutableList = ArrayList() + private val tokensPerSecond: MutableList = ArrayList() + private lateinit var llmModule: LlmModule - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() - val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) - inputStream.close() + val tokenizerFile = File(getTestFilePath(TOKENIZER_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(TOKENIZER_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, tokenizerFile) + inputStream.close() - llmModule = - LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) - } + llmModule = + LlmModule(getTestFilePath(TEST_FILE_NAME), getTestFilePath(TOKENIZER_FILE_NAME), 0.0f) + } - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerate() { - val loadResult = llmModule.load() - // Check that the model can be load successfully - assertEquals(OK.toLong(), loadResult.toLong()) + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerate() { + val loadResult = llmModule.load() + // Check that the model can be load successfully + assertEquals(OK.toLong(), loadResult.toLong()) - llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) - assertEquals(results.size.toLong(), SEQ_LEN.toLong()) - assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) - } + llmModule.generate(TEST_PROMPT, SEQ_LEN, this@LlmModuleInstrumentationTest) + assertEquals(results.size.toLong(), SEQ_LEN.toLong()) + assertTrue(tokensPerSecond[tokensPerSecond.size - 1] > 0) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testGenerateAndStop() { - llmModule.generate( - TEST_PROMPT, - SEQ_LEN, - object : LlmCallback { - override fun onResult(result: String) { - this@LlmModuleInstrumentationTest.onResult(result) - llmModule.stop() - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testGenerateAndStop() { + llmModule.generate( + TEST_PROMPT, + SEQ_LEN, + object : LlmCallback { + override fun onResult(result: String) { + this@LlmModuleInstrumentationTest.onResult(result) + llmModule.stop() + } - override fun onStats(stats: String) { - this@LlmModuleInstrumentationTest.onStats(stats) - } - }, - ) + override fun onStats(stats: String) { + this@LlmModuleInstrumentationTest.onStats(stats) + } + }, + ) - val stoppedResultSize = results.size - assertTrue(stoppedResultSize < SEQ_LEN) - } + val stoppedResultSize = results.size + assertTrue(stoppedResultSize < SEQ_LEN) + } - override fun onResult(result: String) { - results.add(result) - } + override fun onResult(result: String) { + results.add(result) + } - override fun onStats(stats: String) { - var tps = 0f - try { - val jsonObject = JSONObject(stats) - val numGeneratedTokens = jsonObject.getInt("generated_tokens") - val inferenceEndMs = jsonObject.getInt("inference_end_ms") - val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") - tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 - tokensPerSecond.add(tps) - } catch (_: JSONException) { - } - } + override fun onStats(stats: String) { + var tps = 0f + try { + val jsonObject = JSONObject(stats) + val numGeneratedTokens = jsonObject.getInt("generated_tokens") + val inferenceEndMs = jsonObject.getInt("inference_end_ms") + val promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms") + tps = numGeneratedTokens.toFloat() / (inferenceEndMs - promptEvalEndMs) * 1000 + tokensPerSecond.add(tps) + } catch (_: JSONException) {} + } - companion object { - private const val TEST_FILE_NAME = "/stories.pte" - private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" - private const val TEST_PROMPT = "Hello" - private const val OK = 0x00 - private const val SEQ_LEN = 32 - } + companion object { + private const val TEST_FILE_NAME = "/stories.pte" + private const val TOKENIZER_FILE_NAME = "/tokenizer.bin" + private const val TEST_PROMPT = "Hello" + private const val OK = 0x00 + private const val SEQ_LEN = 32 + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 45476dac43f..3d688391673 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -27,83 +27,83 @@ import org.pytorch.executorch.TestFileUtils.getTestFilePath /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleE2ETest { - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - @Throws(IOException::class, URISyntaxException::class) - fun testClassification(filePath: String) { - val pteFile = File(getTestFilePath(filePath)) - val inputStream = javaClass.getResourceAsStream(filePath) - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + @Throws(IOException::class, URISyntaxException::class) + fun testClassification(filePath: String) { + val pteFile = File(getTestFilePath(filePath)) + val inputStream = javaClass.getResourceAsStream(filePath) + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") - var bitmap = BitmapFactory.decodeStream(imgInputStream) - bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) - imgInputStream.close() + val imgInputStream = javaClass.getResourceAsStream("/banana.jpeg") + var bitmap = BitmapFactory.decodeStream(imgInputStream) + bitmap = Bitmap.createScaledBitmap(bitmap!!, 224, 224, true) + imgInputStream.close() - val inputTensor = - bitmapToFloat32Tensor( - bitmap, - TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, - TensorImageUtils.TORCHVISION_NORM_STD_RGB, - ) + val inputTensor = + bitmapToFloat32Tensor( + bitmap, + TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, + TensorImageUtils.TORCHVISION_NORM_STD_RGB, + ) - val module = Module.load(getTestFilePath(filePath)) + val module = Module.load(getTestFilePath(filePath)) - val results = module.forward(EValue.from(inputTensor)) - Assert.assertTrue(results[0].isTensor) - val scores = results[0].toTensor().dataAsFloatArray + val results = module.forward(EValue.from(inputTensor)) + Assert.assertTrue(results[0].isTensor) + val scores = results[0].toTensor().dataAsFloatArray - val bananaClass = 954 // From ImageNet 1K - Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) - } + val bananaClass = 954 // From ImageNet 1K + Assert.assertEquals(bananaClass.toLong(), argmax(scores).toLong()) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testXnnpackBackendRequired() { - val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) - val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") - FileUtils.copyInputStreamToFile(inputStream, pteFile) - inputStream.close() + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testXnnpackBackendRequired() { + val pteFile = File(getTestFilePath("/mv3_xnnpack_fp32.pte")) + val inputStream = javaClass.getResourceAsStream("/mv3_xnnpack_fp32.pte") + FileUtils.copyInputStreamToFile(inputStream, pteFile) + inputStream.close() - val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) - val expectedBackends = arrayOf("XnnpackBackend") - assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").backends) - } + val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) + val expectedBackends = arrayOf("XnnpackBackend") + assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").backends) + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv2Fp32() { - testClassification("/mv2_xnnpack_fp32.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv2Fp32() { + testClassification("/mv2_xnnpack_fp32.pte") + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMv3Fp32() { - testClassification("/mv3_xnnpack_fp32.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMv3Fp32() { + testClassification("/mv3_xnnpack_fp32.pte") + } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testResnet50() { - testClassification("/resnet50_xnnpack_q8.pte") - } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testResnet50() { + testClassification("/resnet50_xnnpack_q8.pte") + } - companion object { + companion object { - fun argmax(array: FloatArray): Int { - require(array.isNotEmpty()) { "Array cannot be empty" } - var maxIndex = 0 - var maxValue = array[0] - for (i in 1 until array.size) { - if (array[i] > maxValue) { - maxValue = array[i] - maxIndex = i - } - } - return maxIndex + fun argmax(array: FloatArray): Int { + require(array.isNotEmpty()) { "Array cannot be empty" } + var maxIndex = 0 + var maxValue = array[0] + for (i in 1 until array.size) { + if (array[i] > maxValue) { + maxValue = array[i] + maxIndex = i } + } + return maxIndex } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index 58e9cc8bfef..698358ba6a7 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -8,7 +8,6 @@ package org.pytorch.executorch import android.Manifest -import androidx.test.InstrumentationRegistry import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule import java.io.File @@ -28,151 +27,151 @@ import org.pytorch.executorch.TestFileUtils.getTestFilePath /** Unit tests for [Module]. */ @RunWith(AndroidJUnit4::class) class ModuleInstrumentationTest { - @Before - @Throws(IOException::class) - fun setUp() { - // copy zipped test resources to local device - val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) - var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, addPteFile) - inputStream.close() - - val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) - inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) - FileUtils.copyInputStreamToFile(inputStream, nonPteFile) - inputStream.close() - } - - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testModuleLoadAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - + @Before + @Throws(IOException::class) + fun setUp() { + // copy zipped test resources to local device + val addPteFile = File(getTestFilePath(TEST_FILE_NAME)) + var inputStream = javaClass.getResourceAsStream(TEST_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, addPteFile) + inputStream.close() + + val nonPteFile = File(getTestFilePath(NON_PTE_FILE_NAME)) + inputStream = javaClass.getResourceAsStream(NON_PTE_FILE_NAME) + FileUtils.copyInputStreamToFile(inputStream, nonPteFile) + inputStream.close() + } + + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testModuleLoadAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + } + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMethodMetadata() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodAndForward() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + val results = module.forward() + Assert.assertTrue(results[0].isTensor) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadForwardExplicit() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val results = module.execute(FORWARD_METHOD) + Assert.assertTrue(results[0].isTensor) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testModuleLoadNonExistantFile() { + val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) + } + + @Test + @Throws(IOException::class) + fun testModuleLoadMethodNonExistantMethod() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(NONE_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test(expected = RuntimeException::class) + @Throws(IOException::class) + fun testNonPteFile() { + val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + } + + @Test + @Throws(IOException::class) + fun testLoadOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + module.destroy() + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) + } + + @Test + @Throws(IOException::class) + fun testForwardOnDestroyedModule() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val loadMethod = module.loadMethod(FORWARD_METHOD) + Assert.assertEquals(loadMethod.toLong(), OK.toLong()) + + module.destroy() + + val results = module.forward() + Assert.assertEquals(0, results.size.toLong()) + } + + @Test + @Throws(InterruptedException::class, IOException::class) + fun testForwardFromMultipleThreads() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + val numThreads = 100 + val latch = CountDownLatch(numThreads) + val completed = AtomicInteger(0) + + val runnable = Runnable { + try { + latch.countDown() + latch.await(5000, TimeUnit.MILLISECONDS) val results = module.forward() Assert.assertTrue(results[0].isTensor) + completed.incrementAndGet() + } catch (_: InterruptedException) {} } - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testMethodMetadata() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + val threads = arrayOfNulls(numThreads) + for (i in 0 until numThreads) { + threads[i] = Thread(runnable) + threads[i]!!.start() } - @Test - @Throws(IOException::class) - fun testModuleLoadMethodAndForward() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadForwardExplicit() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val results = module.execute(FORWARD_METHOD) - Assert.assertTrue(results[0].isTensor) - } - - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testModuleLoadNonExistantFile() { - val module = Module.load(getTestFilePath(MISSING_FILE_NAME)) - } - - @Test - @Throws(IOException::class) - fun testModuleLoadMethodNonExistantMethod() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(NONE_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) + for (i in 0 until numThreads) { + threads[i]!!.join() } - @Test(expected = RuntimeException::class) - @Throws(IOException::class) - fun testNonPteFile() { - val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) - } - - @Test - @Throws(IOException::class) - fun testLoadOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - module.destroy() - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), INVALID_STATE.toLong()) - } - - @Test - @Throws(IOException::class) - fun testForwardOnDestroyedModule() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val loadMethod = module.loadMethod(FORWARD_METHOD) - Assert.assertEquals(loadMethod.toLong(), OK.toLong()) - - module.destroy() - - val results = module.forward() - Assert.assertEquals(0, results.size.toLong()) - } - - @Test - @Throws(InterruptedException::class, IOException::class) - fun testForwardFromMultipleThreads() { - val module = Module.load(getTestFilePath(TEST_FILE_NAME)) - - val numThreads = 100 - val latch = CountDownLatch(numThreads) - val completed = AtomicInteger(0) - - val runnable = Runnable { - try { - latch.countDown() - latch.await(5000, TimeUnit.MILLISECONDS) - val results = module.forward() - Assert.assertTrue(results[0].isTensor) - completed.incrementAndGet() - } catch (_: InterruptedException) {} - } - - val threads = arrayOfNulls(numThreads) - for (i in 0 until numThreads) { - threads[i] = Thread(runnable) - threads[i]!!.start() - } - - for (i in 0 until numThreads) { - threads[i]!!.join() - } - - Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) - } - - companion object { - private const val TEST_FILE_NAME = "/ModuleAdd.pte" - private const val MISSING_FILE_NAME = "/missing.pte" - private const val NON_PTE_FILE_NAME = "/test.txt" - private const val FORWARD_METHOD = "forward" - private const val NONE_METHOD = "none" - private const val OK = 0x00 - private const val INVALID_STATE = 0x2 - private const val INVALID_ARGUMENT = 0x12 - private const val ACCESS_FAILED = 0x22 - } + Assert.assertEquals(numThreads.toLong(), completed.get().toLong()) + } + + companion object { + private const val TEST_FILE_NAME = "/ModuleAdd.pte" + private const val MISSING_FILE_NAME = "/missing.pte" + private const val NON_PTE_FILE_NAME = "/test.txt" + private const val FORWARD_METHOD = "forward" + private const val NONE_METHOD = "none" + private const val OK = 0x00 + private const val INVALID_STATE = 0x2 + private const val INVALID_ARGUMENT = 0x12 + private const val ACCESS_FAILED = 0x22 + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.kt index 1bb4ef21d52..72828bc4535 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/RuntimeInstrumentationTest.kt @@ -12,23 +12,23 @@ import org.junit.Assert import org.junit.Test import org.junit.runner.RunWith -/** Unit tests for [ExecuTorchRuntime]. */ +/** Unit tests for [ExecuTorchRuntime]. */ @RunWith(AndroidJUnit4::class) class RuntimeInstrumentationTest { - @Test - fun testRuntimeApi() { - val ops = ExecuTorchRuntime.getRegisteredOps() - val backends = ExecuTorchRuntime.getRegisteredBackends() + @Test + fun testRuntimeApi() { + val ops = ExecuTorchRuntime.getRegisteredOps() + val backends = ExecuTorchRuntime.getRegisteredBackends() - Assert.assertNotNull(ops) - Assert.assertNotNull(backends) + Assert.assertNotNull(ops) + Assert.assertNotNull(backends) - for (op in ops) { - Assert.assertNotNull(op) - } + for (op in ops) { + Assert.assertNotNull(op) + } - for (backend in backends) { - Assert.assertNotNull(backend) - } + for (backend in backends) { + Assert.assertNotNull(backend) } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt index cb2e365a4c5..de5240bc343 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TensorImageUtils.kt @@ -16,147 +16,147 @@ import java.nio.FloatBuffer * [android.media.Image] source. */ object TensorImageUtils { - @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) + @JvmField var TORCHVISION_NORM_MEAN_RGB: FloatArray = floatArrayOf(0.485f, 0.456f, 0.406f) - @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) + @JvmField var TORCHVISION_NORM_STD_RGB: FloatArray = floatArrayOf(0.229f, 0.224f, 0.225f) - /** - * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in - * parameters mean and std. - * - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - @JvmStatic - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from full [android.graphics.Bitmap], normalized with specified in + * parameters mean and std. + * + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + @JvmStatic + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - return bitmapToFloat32Tensor( - bitmap, - 0, - 0, - bitmap.width, - bitmap.height, - normMeanRGB, - normStdRGB, - ) - } + return bitmapToFloat32Tensor( + bitmap, + 0, + 0, + bitmap.width, + bitmap.height, + normMeanRGB, + normStdRGB, + ) + } - /** - * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in - * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloatBuffer( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - outBuffer: FloatBuffer, - outBufferOffset: Int, - ) { - checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) - val pixelsCount = height * width - val pixels = IntArray(pixelsCount) - bitmap.getPixels(pixels, 0, width, x, y, width, height) - val offsetB = 2 * pixelsCount - for (i in 0..99) { - val c = pixels[i] - Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) - } - for (i in 0 until pixelsCount) { - val c = pixels[i] - val r = ((c shr 16) and 0xff) / 255.0f - val g = ((c shr 8) and 0xff) / 255.0f - val b = ((c) and 0xff) / 255.0f - outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) - outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) - outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) - } + /** + * Writes tensor content from specified [android.graphics.Bitmap], normalized with specified in + * parameters mean and std to specified [java.nio.FloatBuffer] with specified offset. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloatBuffer( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + outBuffer: FloatBuffer, + outBufferOffset: Int, + ) { + checkOutBufferCapacity(outBuffer, outBufferOffset, width, height) + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) + val pixelsCount = height * width + val pixels = IntArray(pixelsCount) + bitmap.getPixels(pixels, 0, width, x, y, width, height) + val offsetB = 2 * pixelsCount + for (i in 0..99) { + val c = pixels[i] + Log.i("Image", ": " + i + " " + ((c shr 16) and 0xff)) } + for (i in 0 until pixelsCount) { + val c = pixels[i] + val r = ((c shr 16) and 0xff) / 255.0f + val g = ((c shr 8) and 0xff) / 255.0f + val b = ((c) and 0xff) / 255.0f + outBuffer.put(outBufferOffset + i, (r - normMeanRGB[0]) / normStdRGB[0]) + outBuffer.put(outBufferOffset + pixelsCount + i, (g - normMeanRGB[1]) / normStdRGB[1]) + outBuffer.put(outBufferOffset + offsetB + i, (b - normMeanRGB[2]) / normStdRGB[2]) + } + } - /** - * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with - * specified in parameters mean and std. - * - * @param bitmap [android.graphics.Bitmap] as a source for Tensor data - * @param x - x coordinate of top left corner of bitmap's area - * @param y - y coordinate of top left corner of bitmap's area - * @param width - width of bitmap's area - * @param height - height of bitmap's area - * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order - * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB - * order - */ - fun bitmapToFloat32Tensor( - bitmap: Bitmap, - x: Int, - y: Int, - width: Int, - height: Int, - normMeanRGB: FloatArray, - normStdRGB: FloatArray, - ): Tensor { - checkNormMeanArg(normMeanRGB) - checkNormStdArg(normStdRGB) + /** + * Creates new [Tensor] from specified area of [android.graphics.Bitmap], normalized with + * specified in parameters mean and std. + * + * @param bitmap [android.graphics.Bitmap] as a source for Tensor data + * @param x - x coordinate of top left corner of bitmap's area + * @param y - y coordinate of top left corner of bitmap's area + * @param width - width of bitmap's area + * @param height - height of bitmap's area + * @param normMeanRGB means for RGB channels normalization, length must equal 3, RGB order + * @param normStdRGB standard deviation for RGB channels normalization, length must equal 3, RGB + * order + */ + fun bitmapToFloat32Tensor( + bitmap: Bitmap, + x: Int, + y: Int, + width: Int, + height: Int, + normMeanRGB: FloatArray, + normStdRGB: FloatArray, + ): Tensor { + checkNormMeanArg(normMeanRGB) + checkNormStdArg(normStdRGB) - val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) - bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) - return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) - } + val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) + bitmapToFloatBuffer(bitmap, x, y, width, height, normMeanRGB, normStdRGB, floatBuffer, 0) + return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) + } - private fun checkOutBufferCapacity( - outBuffer: FloatBuffer, - outBufferOffset: Int, - tensorWidth: Int, - tensorHeight: Int, - ) { - check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { - "Buffer underflow" - } + private fun checkOutBufferCapacity( + outBuffer: FloatBuffer, + outBufferOffset: Int, + tensorWidth: Int, + tensorHeight: Int, + ) { + check(outBufferOffset + 3 * tensorWidth * tensorHeight <= outBuffer.capacity()) { + "Buffer underflow" } + } - private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { - require(!(tensorHeight <= 0 || tensorWidth <= 0)) { - "tensorHeight and tensorWidth must be positive" - } + private fun checkTensorSize(tensorWidth: Int, tensorHeight: Int) { + require(!(tensorHeight <= 0 || tensorWidth <= 0)) { + "tensorHeight and tensorWidth must be positive" } + } - private fun checkRotateCWDegrees(rotateCWDegrees: Int) { - require( - !(rotateCWDegrees != 0 && - rotateCWDegrees != 90 && - rotateCWDegrees != 180 && - rotateCWDegrees != 270) - ) { - "rotateCWDegrees must be one of 0, 90, 180, 270" - } + private fun checkRotateCWDegrees(rotateCWDegrees: Int) { + require( + !(rotateCWDegrees != 0 && + rotateCWDegrees != 90 && + rotateCWDegrees != 180 && + rotateCWDegrees != 270) + ) { + "rotateCWDegrees must be one of 0, 90, 180, 270" } + } - private fun checkNormStdArg(normStdRGB: FloatArray) { - require(normStdRGB.size == 3) { "normStdRGB length must be 3" } - } + private fun checkNormStdArg(normStdRGB: FloatArray) { + require(normStdRGB.size == 3) { "normStdRGB length must be 3" } + } - private fun checkNormMeanArg(normMeanRGB: FloatArray) { - require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } - } + private fun checkNormMeanArg(normMeanRGB: FloatArray) { + require(normMeanRGB.size == 3) { "normMeanRGB length must be 3" } + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt index efa364f8e94..12fee88c5a1 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/TestFileUtils.kt @@ -2,15 +2,11 @@ package org.pytorch.executorch import androidx.test.InstrumentationRegistry -/** - * Test File Utils - */ +/** Test File Utils */ object TestFileUtils { - fun getTestFilePath(fileName: String): String { - return InstrumentationRegistry.getInstrumentation() - .targetContext - .externalCacheDir - .toString() + fileName - } + fun getTestFilePath(fileName: String): String { + return InstrumentationRegistry.getInstrumentation().targetContext.externalCacheDir.toString() + + fileName + } } diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt index d71cc6aaedd..778c8fa46f2 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/training/TrainingModuleE2ETest.kt @@ -12,6 +12,11 @@ import android.Manifest import android.util.Log import androidx.test.ext.junit.runners.AndroidJUnit4 import androidx.test.rule.GrantPermissionRule +import java.io.File +import java.io.IOException +import java.net.URISyntaxException +import kotlin.random.Random +import kotlin.test.assertContains import org.apache.commons.io.FileUtils import org.junit.Assert import org.junit.Rule @@ -20,49 +25,46 @@ import org.junit.runner.RunWith import org.pytorch.executorch.EValue import org.pytorch.executorch.Tensor import org.pytorch.executorch.TestFileUtils -import java.io.File -import java.io.IOException -import java.net.URISyntaxException -import kotlin.random.Random -import kotlin.test.assertContains /** Unit tests for [TrainingModule]. */ @RunWith(AndroidJUnit4::class) class TrainingModuleE2ETest { - @get:Rule - var runtimePermissionRule: GrantPermissionRule = - GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testTrainXOR() { - val pteFilePath = "/xor.pte" - val ptdFilePath = "/xor.ptd" - - val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) - val pteInputStream = javaClass.getResourceAsStream(pteFilePath) - FileUtils.copyInputStreamToFile(pteInputStream, pteFile) - pteInputStream.close() - - val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath)) - val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath) - FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile) - ptdInputStream.close() - - val module = TrainingModule.load( + @get:Rule + var runtimePermissionRule: GrantPermissionRule = + GrantPermissionRule.grant(Manifest.permission.READ_EXTERNAL_STORAGE) + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testTrainXOR() { + val pteFilePath = "/xor.pte" + val ptdFilePath = "/xor.ptd" + + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + val ptdFile = File(TestFileUtils.getTestFilePath(ptdFilePath)) + val ptdInputStream = javaClass.getResourceAsStream(ptdFilePath) + FileUtils.copyInputStreamToFile(ptdInputStream, ptdFile) + ptdInputStream.close() + + val module = + TrainingModule.load( TestFileUtils.getTestFilePath(pteFilePath), - TestFileUtils.getTestFilePath(ptdFilePath) + TestFileUtils.getTestFilePath(ptdFilePath), ) - val params = module.namedParameters("forward") + val params = module.namedParameters("forward") - Assert.assertEquals(4, params.size) - assertContains(params, LIN_WEIGHT) - assertContains(params, LIN_BIAS) - assertContains(params, LIN2_WEIGHT) - assertContains(params, LIN2_BIAS) + Assert.assertEquals(4, params.size) + assertContains(params, LIN_WEIGHT) + assertContains(params, LIN_BIAS) + assertContains(params, LIN2_WEIGHT) + assertContains(params, LIN2_BIAS) - val sgd = SGD.create(params, 0.5); - val dataset = listOf( + val sgd = SGD.create(params, 0.5) + val dataset = + listOf( Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)), Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)), @@ -73,73 +75,71 @@ class TrainingModuleE2ETest { Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), ) - val numEpochs = 5000; - var finalLoss = Float.MAX_VALUE - - for (i in 0 until numEpochs) { - val inputDex = 2 * Random.nextInt(dataset.size / 2) - val targetDex = inputDex + 1 - val input = dataset.get(inputDex) - val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", - EValue.from(input), - EValue.from(target) - ) - val gradients = module.namedGradients("forward") - - if (i == 0) { - Assert.assertEquals(4, gradients.size) - assertContains(gradients, LIN_WEIGHT) - assertContains(gradients, LIN_BIAS) - assertContains(gradients, LIN2_WEIGHT) - assertContains(gradients, LIN2_BIAS) - } - - if (i % 500 == 0 || i == numEpochs - 1) { - Log.i( - "testTrainXOR", - String.format( - "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", - i, - out[0].toTensor().getDataAsFloatArray()[0], - input.getDataAsFloatArray()[0], - input.getDataAsFloatArray()[1], - out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0] - ) - ); - } - - sgd.step(gradients) - - if (i == numEpochs - 1) { - finalLoss = out[0].toTensor().dataAsFloatArray[0] - } - } - Assert.assertTrue(finalLoss < 0.1f) - } - - @Test - @Throws(IOException::class, URISyntaxException::class) - fun testTrainXOR_PTEOnly() { - val pteFilePath = "/xor_full.pte" - - val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) - val pteInputStream = javaClass.getResourceAsStream(pteFilePath) - FileUtils.copyInputStreamToFile(pteInputStream, pteFile) - pteInputStream.close() - - val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath)); - val params = module.namedParameters("forward") + val numEpochs = 5000 + var finalLoss = Float.MAX_VALUE + + for (i in 0 until numEpochs) { + val inputDex = 2 * Random.nextInt(dataset.size / 2) + val targetDex = inputDex + 1 + val input = dataset.get(inputDex) + val target = dataset.get(targetDex) + val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val gradients = module.namedGradients("forward") + + if (i == 0) { + Assert.assertEquals(4, gradients.size) + assertContains(gradients, LIN_WEIGHT) + assertContains(gradients, LIN_BIAS) + assertContains(gradients, LIN2_WEIGHT) + assertContains(gradients, LIN2_BIAS) + } + + if (i % 500 == 0 || i == numEpochs - 1) { + Log.i( + "testTrainXOR", + String.format( + "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", + i, + out[0].toTensor().getDataAsFloatArray()[0], + input.getDataAsFloatArray()[0], + input.getDataAsFloatArray()[1], + out[1].toTensor().getDataAsLongArray()[0], + target.getDataAsLongArray()[0], + ), + ) + } - Assert.assertEquals(4, params.size) - assertContains(params, LIN_WEIGHT) - assertContains(params, LIN_BIAS) - assertContains(params, LIN2_WEIGHT) - assertContains(params, LIN2_BIAS) + sgd.step(gradients) - val sgd = SGD.create(params, 0.5); - val dataset = listOf( + if (i == numEpochs - 1) { + finalLoss = out[0].toTensor().dataAsFloatArray[0] + } + } + Assert.assertTrue(finalLoss < 0.1f) + } + + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testTrainXOR_PTEOnly() { + val pteFilePath = "/xor_full.pte" + + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + val module = TrainingModule.load(TestFileUtils.getTestFilePath(pteFilePath)) + val params = module.namedParameters("forward") + + Assert.assertEquals(4, params.size) + assertContains(params, LIN_WEIGHT) + assertContains(params, LIN_BIAS) + assertContains(params, LIN2_WEIGHT) + assertContains(params, LIN2_BIAS) + + val sgd = SGD.create(params, 0.5) + val dataset = + listOf( Tensor.fromBlob(floatArrayOf(1.0f, 1.0f), longArrayOf(1, 2)), Tensor.fromBlob(longArrayOf(0), longArrayOf(1)), Tensor.fromBlob(floatArrayOf(0.0f, 0.0f), longArrayOf(1, 2)), @@ -150,91 +150,90 @@ class TrainingModuleE2ETest { Tensor.fromBlob(longArrayOf(1), longArrayOf(1)), ) - val numEpochs = 5000; - var finalLoss = Float.MAX_VALUE - - for (i in 0 until numEpochs) { - val inputDex = 2 * Random.nextInt(dataset.size / 2) - val targetDex = inputDex + 1 - val input = dataset.get(inputDex) - val target = dataset.get(targetDex) - val out = module.executeForwardBackward("forward", - EValue.from(input), - EValue.from(target) - ) - val gradients = module.namedGradients("forward") - - if (i == 0) { - Assert.assertEquals(4, gradients.size) - assertContains(gradients, LIN_WEIGHT) - assertContains(gradients, LIN_BIAS) - assertContains(gradients, LIN2_WEIGHT) - assertContains(gradients, LIN2_BIAS) - } - - if (i % 500 == 0 || i == numEpochs - 1) { - Log.i( - "testTrainXOR_PTEOnly", - String.format( - "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", - i, - out[0].toTensor().getDataAsFloatArray()[0], - input.getDataAsFloatArray()[0], - input.getDataAsFloatArray()[1], - out[1].toTensor().getDataAsLongArray()[0], - target.getDataAsLongArray()[0] - ) - ); - } - - sgd.step(gradients) - - if (i == numEpochs - 1) { - finalLoss = out[0].toTensor().dataAsFloatArray[0] - } - } - Assert.assertTrue(finalLoss < 0.1f) - } - - @Test - @Throws(IOException::class) - fun testMissingPteFile() { - val exception = Assert.assertThrows(RuntimeException::class.java) { - TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME)) - } - Assert.assertEquals( - exception.message, - "Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME) + val numEpochs = 5000 + var finalLoss = Float.MAX_VALUE + + for (i in 0 until numEpochs) { + val inputDex = 2 * Random.nextInt(dataset.size / 2) + val targetDex = inputDex + 1 + val input = dataset.get(inputDex) + val target = dataset.get(targetDex) + val out = module.executeForwardBackward("forward", EValue.from(input), EValue.from(target)) + val gradients = module.namedGradients("forward") + + if (i == 0) { + Assert.assertEquals(4, gradients.size) + assertContains(gradients, LIN_WEIGHT) + assertContains(gradients, LIN_BIAS) + assertContains(gradients, LIN2_WEIGHT) + assertContains(gradients, LIN2_BIAS) + } + + if (i % 500 == 0 || i == numEpochs - 1) { + Log.i( + "testTrainXOR_PTEOnly", + String.format( + "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %d, Label %d", + i, + out[0].toTensor().getDataAsFloatArray()[0], + input.getDataAsFloatArray()[0], + input.getDataAsFloatArray()[1], + out[1].toTensor().getDataAsLongArray()[0], + target.getDataAsLongArray()[0], + ), ) - } + } - @Test - @Throws(IOException::class) - fun testMissingPtdFile() { - val exception = Assert.assertThrows(RuntimeException::class.java) { - val pteFilePath = "/xor.pte" - val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) - val pteInputStream = javaClass.getResourceAsStream(pteFilePath) - FileUtils.copyInputStreamToFile(pteInputStream, pteFile) - pteInputStream.close() - - TrainingModule.load( - TestFileUtils.getTestFilePath(pteFilePath), - TestFileUtils.getTestFilePath(MISSING_PTD_NAME) - ) - } - Assert.assertEquals( - exception.message, - "Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME) - ) - } + sgd.step(gradients) - companion object { - private const val LIN_WEIGHT = "net.linear.weight" - private const val LIN_BIAS = "net.linear.bias" - private const val LIN2_WEIGHT = "net.linear2.weight" - private const val LIN2_BIAS = "net.linear2.bias" - private const val MISSING_PTE_NAME = "/missing.pte" - private const val MISSING_PTD_NAME = "/missing.ptd" + if (i == numEpochs - 1) { + finalLoss = out[0].toTensor().dataAsFloatArray[0] + } } -} \ No newline at end of file + Assert.assertTrue(finalLoss < 0.1f) + } + + @Test + @Throws(IOException::class) + fun testMissingPteFile() { + val exception = + Assert.assertThrows(RuntimeException::class.java) { + TrainingModule.load(TestFileUtils.getTestFilePath(MISSING_PTE_NAME)) + } + Assert.assertEquals( + exception.message, + "Cannot load model path!! " + TestFileUtils.getTestFilePath(MISSING_PTE_NAME), + ) + } + + @Test + @Throws(IOException::class) + fun testMissingPtdFile() { + val exception = + Assert.assertThrows(RuntimeException::class.java) { + val pteFilePath = "/xor.pte" + val pteFile = File(TestFileUtils.getTestFilePath(pteFilePath)) + val pteInputStream = javaClass.getResourceAsStream(pteFilePath) + FileUtils.copyInputStreamToFile(pteInputStream, pteFile) + pteInputStream.close() + + TrainingModule.load( + TestFileUtils.getTestFilePath(pteFilePath), + TestFileUtils.getTestFilePath(MISSING_PTD_NAME), + ) + } + Assert.assertEquals( + exception.message, + "Cannot load data path!! " + TestFileUtils.getTestFilePath(MISSING_PTD_NAME), + ) + } + + companion object { + private const val LIN_WEIGHT = "net.linear.weight" + private const val LIN_BIAS = "net.linear.bias" + private const val LIN2_WEIGHT = "net.linear2.weight" + private const val LIN2_BIAS = "net.linear2.bias" + private const val MISSING_PTE_NAME = "/missing.pte" + private const val MISSING_PTD_NAME = "/missing.ptd" + } +} diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index b39b9be9c3b..7e9fea9a699 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -18,200 +18,200 @@ import org.junit.runners.JUnit4 /** Unit tests for [EValue]. */ @RunWith(JUnit4::class) class EValueTest { - @Test - fun testNone() { - val evalue = EValue.optionalNone() - assertTrue(evalue.isNone) + @Test + fun testNone() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + } + + @Test + fun testTensorValue() { + val data = longArrayOf(1, 2, 3) + val shape = longArrayOf(1, 3) + val evalue = EValue.from(Tensor.fromBlob(data, shape)) + assertTrue(evalue.isTensor) + assertTrue(evalue.toTensor().shape.contentEquals(shape)) + assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + } + + @Test + fun testBoolValue() { + val evalue = EValue.from(true) + assertTrue(evalue.isBool) + assertTrue(evalue.toBool()) + } + + @Test + fun testIntValue() { + val evalue = EValue.from(1) + assertTrue(evalue.isInt) + assertEquals(evalue.toInt(), 1) + } + + @Test + fun testDoubleValue() { + val evalue = EValue.from(0.1) + assertTrue(evalue.isDouble) + assertEquals(evalue.toDouble(), 0.1, 0.0001) + } + + @Test + fun testStringValue() { + val evalue = EValue.from("a") + assertTrue(evalue.isString) + assertEquals(evalue.toStr(), "a") + } + + @Test + fun testAllIllegalCast() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + + // try Tensor + assertFalse(evalue.isTensor) + assertThatThrownBy { evalue.toTensor() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Tensor, actual type None") + + // try bool + assertFalse(evalue.isBool) + assertThatThrownBy { evalue.toBool() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Bool, actual type None") + + // try int + assertFalse(evalue.isInt) + assertThatThrownBy { evalue.toInt() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Int, actual type None") + + // try double + assertFalse(evalue.isDouble) + assertThatThrownBy { evalue.toDouble() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Double, actual type None") + + // try string + assertFalse(evalue.isString) + assertThatThrownBy { evalue.toStr() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type String, actual type None") + } + + @Test + fun testNoneSerde() { + val evalue = EValue.optionalNone() + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isNone, true) + } + + @Test + fun testBoolSerde() { + val evalue = EValue.from(true) + val bytes = evalue.toByteArray() + assertEquals(1, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), true) + } + + @Test + fun testBoolSerde2() { + val evalue = EValue.from(false) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), false) + } + + @Test + fun testIntSerde() { + val evalue = EValue.from(1) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + assertEquals(0, bytes[2].toLong()) + assertEquals(0, bytes[3].toLong()) + assertEquals(0, bytes[4].toLong()) + assertEquals(0, bytes[5].toLong()) + assertEquals(0, bytes[6].toLong()) + assertEquals(0, bytes[7].toLong()) + assertEquals(1, bytes[8].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 1) + } + + @Test + fun testLargeIntSerde() { + val evalue = EValue.from(256000) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 256000) + } + + @Test + fun testDoubleSerde() { + val evalue = EValue.from(1.345e-2) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isDouble, true) + assertEquals(1.345e-2, deser.toDouble(), 1e-6) + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testTensorValue() { - val data = longArrayOf(1, 2, 3) - val shape = longArrayOf(1, 3) - val evalue = EValue.from(Tensor.fromBlob(data, shape)) - assertTrue(evalue.isTensor) - assertTrue(evalue.toTensor().shape.contentEquals(shape)) - assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } - @Test - fun testBoolValue() { - val evalue = EValue.from(true) - assertTrue(evalue.isBool) - assertTrue(evalue.toBool()) - } - - @Test - fun testIntValue() { - val evalue = EValue.from(1) - assertTrue(evalue.isInt) - assertEquals(evalue.toInt(), 1) - } - - @Test - fun testDoubleValue() { - val evalue = EValue.from(0.1) - assertTrue(evalue.isDouble) - assertEquals(evalue.toDouble(), 0.1, 0.0001) - } - - @Test - fun testStringValue() { - val evalue = EValue.from("a") - assertTrue(evalue.isString) - assertEquals(evalue.toStr(), "a") - } - - @Test - fun testAllIllegalCast() { - val evalue = EValue.optionalNone() - assertTrue(evalue.isNone) - - // try Tensor - assertFalse(evalue.isTensor) - assertThatThrownBy { evalue.toTensor() } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Tensor, actual type None") - - // try bool - assertFalse(evalue.isBool) - assertThatThrownBy { evalue.toBool() } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Bool, actual type None") - - // try int - assertFalse(evalue.isInt) - assertThatThrownBy { evalue.toInt() } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Int, actual type None") - - // try double - assertFalse(evalue.isDouble) - assertThatThrownBy { evalue.toDouble() } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Double, actual type None") - - // try string - assertFalse(evalue.isString) - assertThatThrownBy { evalue.toStr() } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type String, actual type None") - } - - @Test - fun testNoneSerde() { - val evalue = EValue.optionalNone() - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isNone, true) - } - - @Test - fun testBoolSerde() { - val evalue = EValue.from(true) - val bytes = evalue.toByteArray() - assertEquals(1, bytes[1].toLong()) - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isBool, true) - assertEquals(deser.toBool(), true) - } - - @Test - fun testBoolSerde2() { - val evalue = EValue.from(false) - val bytes = evalue.toByteArray() - assertEquals(0, bytes[1].toLong()) + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isBool, true) - assertEquals(deser.toBool(), false) - } - - @Test - fun testIntSerde() { - val evalue = EValue.from(1) - val bytes = evalue.toByteArray() - assertEquals(0, bytes[1].toLong()) - assertEquals(0, bytes[2].toLong()) - assertEquals(0, bytes[3].toLong()) - assertEquals(0, bytes[4].toLong()) - assertEquals(0, bytes[5].toLong()) - assertEquals(0, bytes[6].toLong()) - assertEquals(0, bytes[7].toLong()) - assertEquals(1, bytes[8].toLong()) - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isInt, true) - assertEquals(deser.toInt(), 1) - } - - @Test - fun testLargeIntSerde() { - val evalue = EValue.from(256000) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isInt, true) - assertEquals(deser.toInt(), 256000) - } - - @Test - fun testDoubleSerde() { - val evalue = EValue.from(1.345e-2) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isDouble, true) - assertEquals(1.345e-2, deser.toDouble(), 1e-6) - } - - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsLongArray + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() - for (i in data.indices) { - assertEquals(data[i], deserData[i]) - } + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsFloatArray - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsFloatArray - - for (i in data.indices) { - assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } } diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index e676ebb3e75..e59b40030d7 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -16,262 +16,262 @@ import org.junit.runners.JUnit4 /** Unit tests for [Tensor]. */ @RunWith(JUnit4::class) class TensorTest { - @Test - fun testFloatTensor() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) - - val floatBuffer = Tensor.allocateFloatBuffer(4) - floatBuffer.put(data) - tensor = Tensor.fromBlob(floatBuffer, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + @Test + fun testFloatTensor() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + + val floatBuffer = Tensor.allocateFloatBuffer(4) + floatBuffer.put(data) + tensor = Tensor.fromBlob(floatBuffer, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + } + + @Test + fun testIntTensor() { + val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) + val shape = longArrayOf(1, 4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + + val intBuffer = Tensor.allocateIntBuffer(4) + intBuffer.put(data) + tensor = Tensor.fromBlob(intBuffer, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + } + + @Test + fun testDoubleTensor() { + val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) + val shape = longArrayOf(1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + + val doubleBuffer = Tensor.allocateDoubleBuffer(4) + doubleBuffer.put(data) + tensor = Tensor.fromBlob(doubleBuffer, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + } + + @Test + fun testLongTensor() { + val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) + val shape = longArrayOf(4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) + + val longBuffer = Tensor.allocateLongBuffer(4) + longBuffer.put(data) + tensor = Tensor.fromBlob(longBuffer, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) + } + + @Test + fun testSignedByteTensor() { + val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) + val shape = longArrayOf(1, 1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlob(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + } + + @Test + fun testUnsignedByteTensor() { + val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) + val shape = longArrayOf(4, 1, 1) + var tensor = Tensor.fromBlobUnsigned(data, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + } + + @Test + fun testIllegalDataTypeException() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + + assertThatThrownBy { tensor.dataAsByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") + + assertThatThrownBy { tensor.dataAsUnsignedByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") + + assertThatThrownBy { tensor.dataAsIntArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") + + assertThatThrownBy { tensor.dataAsDoubleArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + + assertThatThrownBy { tensor.dataAsLongArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") + } + + @Test + fun testIllegalArguments() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shapeWithNegativeValues = longArrayOf(-1, 2) + val mismatchShape = longArrayOf(1, 2) + + assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Data array must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, null) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape elements must be non negative") + + assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() + + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testIntTensor() { - val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) - val shape = longArrayOf(1, 4, 1) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT32) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) - - val intBuffer = Tensor.allocateIntBuffer(4) - intBuffer.put(data) - tensor = Tensor.fromBlob(intBuffer, shape) - assertEquals(tensor.dtype(), DType.INT32) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } - @Test - fun testDoubleTensor() { - val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) - val shape = longArrayOf(1, 4) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.DOUBLE) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - - val doubleBuffer = Tensor.allocateDoubleBuffer(4) - doubleBuffer.put(data) - tensor = Tensor.fromBlob(doubleBuffer, shape) - assertEquals(tensor.dtype(), DType.DOUBLE) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - } - - @Test - fun testLongTensor() { - val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) - val shape = longArrayOf(4, 1) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT64) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsLongArray[0]) - assertEquals(data[1], tensor.dataAsLongArray[1]) - assertEquals(data[2], tensor.dataAsLongArray[2]) - assertEquals(data[3], tensor.dataAsLongArray[3]) - - val longBuffer = Tensor.allocateLongBuffer(4) - longBuffer.put(data) - tensor = Tensor.fromBlob(longBuffer, shape) - assertEquals(tensor.dtype(), DType.INT64) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsLongArray[0]) - assertEquals(data[1], tensor.dataAsLongArray[1]) - assertEquals(data[2], tensor.dataAsLongArray[2]) - assertEquals(data[3], tensor.dataAsLongArray[3]) - } + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() - @Test - fun testSignedByteTensor() { - val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) - val shape = longArrayOf(1, 1, 4) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlob(byteBuffer, shape) - assertEquals(tensor.dtype(), DType.INT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - } + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsFloatArray - @Test - fun testUnsignedByteTensor() { - val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) - val shape = longArrayOf(4, 1, 1) - var tensor = Tensor.fromBlobUnsigned(data, shape) - assertEquals(tensor.dtype(), DType.UINT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) - assertEquals(tensor.dtype(), DType.UINT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } - @Test - fun testIllegalDataTypeException() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - - assertThatThrownBy { tensor.dataAsByteArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") - - assertThatThrownBy { tensor.dataAsUnsignedByteArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") - - assertThatThrownBy { tensor.dataAsIntArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") - - assertThatThrownBy { tensor.dataAsDoubleArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") - - assertThatThrownBy { tensor.dataAsLongArray } - .isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") - } - - @Test - fun testIllegalArguments() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shapeWithNegativeValues = longArrayOf(-1, 2) - val mismatchShape = longArrayOf(1, 2) - - assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Data array must be not null") - - assertThatThrownBy { Tensor.fromBlob(data, null) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Shape must be not null") - - assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Shape elements must be non negative") - - assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) } - .isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") - } - - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() - - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsLongArray - - for (i in data.indices) { - assertEquals(data[i], deserData[i]) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } - } - - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() - - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsFloatArray - - for (i in data.indices) { - assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } }