Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extension/android/executorch_android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies {
implementation 'com.facebook.soloader:nativeloader:0.10.5'
implementation libs.core.ktx
testImplementation 'junit:junit:4.12'
testImplementation 'org.assertj:assertj-core:3.27.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
androidTestImplementation 'androidx.test:rules:1.2.0'
androidTestImplementation 'commons-io:commons-io:2.4'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ public byte[] getDataAsByteArray() {
*/
public byte[] getDataAsUnsignedByteArray() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
"Tensor of type " + getClass().getSimpleName() + " cannot return data as unsigned byte array.");
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
*/
package org.pytorch.executorch

import org.junit.Assert
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
Expand All @@ -18,91 +21,76 @@ class EValueTest {
@Test
fun testNone() {
val evalue = EValue.optionalNone()
Assert.assertTrue(evalue.isNone)
assertTrue(evalue.isNone)
}

@Test
fun testTensorValue() {
val data = longArrayOf(1, 2, 3)
val shape = longArrayOf(1, 3)
val evalue = EValue.from(Tensor.fromBlob(data, shape))
Assert.assertTrue(evalue.isTensor)
Assert.assertTrue(evalue.toTensor().shape.contentEquals(shape))
Assert.assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
assertTrue(evalue.isTensor)
assertTrue(evalue.toTensor().shape.contentEquals(shape))
assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
}

@Test
fun testBoolValue() {
val evalue = EValue.from(true)
Assert.assertTrue(evalue.isBool)
Assert.assertTrue(evalue.toBool())
assertTrue(evalue.isBool)
assertTrue(evalue.toBool())
}

@Test
fun testIntValue() {
val evalue = EValue.from(1)
Assert.assertTrue(evalue.isInt)
Assert.assertEquals(evalue.toInt(), 1)
assertTrue(evalue.isInt)
assertEquals(evalue.toInt(), 1)
}

@Test
fun testDoubleValue() {
val evalue = EValue.from(0.1)
Assert.assertTrue(evalue.isDouble)
Assert.assertEquals(evalue.toDouble(), 0.1, 0.0001)
assertTrue(evalue.isDouble)
assertEquals(evalue.toDouble(), 0.1, 0.0001)
}

@Test
fun testStringValue() {
val evalue = EValue.from("a")
Assert.assertTrue(evalue.isString)
Assert.assertEquals(evalue.toStr(), "a")
assertTrue(evalue.isString)
assertEquals(evalue.toStr(), "a")
}

@Test
fun testAllIllegalCast() {
val evalue = EValue.optionalNone()
Assert.assertTrue(evalue.isNone)
assertTrue(evalue.isNone)

// try Tensor
Assert.assertFalse(evalue.isTensor)
try {
evalue.toTensor()
Assert.fail("Should have thrown an exception")
} catch (e: IllegalStateException) {
}
assertFalse(evalue.isTensor)
assertThatThrownBy {
evalue.toTensor() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Tensor, actual type None")

// try bool
Assert.assertFalse(evalue.isBool)
try {
evalue.toBool()
Assert.fail("Should have thrown an exception")
} catch (e: IllegalStateException) {
}
assertFalse(evalue.isBool)
assertThatThrownBy {
evalue.toBool() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Bool, actual type None")

// try int
Assert.assertFalse(evalue.isInt)
try {
evalue.toInt()
Assert.fail("Should have thrown an exception")
} catch (e: IllegalStateException) {
}
assertFalse(evalue.isInt)
assertThatThrownBy {
evalue.toInt() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Int, actual type None")

// try double
Assert.assertFalse(evalue.isDouble)
try {
evalue.toDouble()
Assert.fail("Should have thrown an exception")
} catch (e: IllegalStateException) {
}
assertFalse(evalue.isDouble)
assertThatThrownBy {
evalue.toDouble() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Double, actual type None")

// try string
Assert.assertFalse(evalue.isString)
try {
evalue.toStr()
Assert.fail("Should have thrown an exception")
} catch (e: IllegalStateException) {
}
assertFalse(evalue.isString)
assertThatThrownBy {
evalue.toStr() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type String, actual type None")
}

@Test
Expand All @@ -111,47 +99,47 @@ class EValueTest {
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isNone, true)
assertEquals(deser.isNone, true)
}

@Test
fun testBoolSerde() {
val evalue = EValue.from(true)
val bytes = evalue.toByteArray()
Assert.assertEquals(1, bytes[1].toLong())
assertEquals(1, bytes[1].toLong())

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isBool, true)
Assert.assertEquals(deser.toBool(), true)
assertEquals(deser.isBool, true)
assertEquals(deser.toBool(), true)
}

@Test
fun testBoolSerde2() {
val evalue = EValue.from(false)
val bytes = evalue.toByteArray()
Assert.assertEquals(0, bytes[1].toLong())
assertEquals(0, bytes[1].toLong())

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isBool, true)
Assert.assertEquals(deser.toBool(), false)
assertEquals(deser.isBool, true)
assertEquals(deser.toBool(), false)
}

@Test
fun testIntSerde() {
val evalue = EValue.from(1)
val bytes = evalue.toByteArray()
Assert.assertEquals(0, bytes[1].toLong())
Assert.assertEquals(0, bytes[2].toLong())
Assert.assertEquals(0, bytes[3].toLong())
Assert.assertEquals(0, bytes[4].toLong())
Assert.assertEquals(0, bytes[5].toLong())
Assert.assertEquals(0, bytes[6].toLong())
Assert.assertEquals(0, bytes[7].toLong())
Assert.assertEquals(1, bytes[8].toLong())
assertEquals(0, bytes[1].toLong())
assertEquals(0, bytes[2].toLong())
assertEquals(0, bytes[3].toLong())
assertEquals(0, bytes[4].toLong())
assertEquals(0, bytes[5].toLong())
assertEquals(0, bytes[6].toLong())
assertEquals(0, bytes[7].toLong())
assertEquals(1, bytes[8].toLong())

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isInt, true)
Assert.assertEquals(deser.toInt(), 1)
assertEquals(deser.isInt, true)
assertEquals(deser.toInt(), 1)
}

@Test
Expand All @@ -160,8 +148,8 @@ class EValueTest {
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isInt, true)
Assert.assertEquals(deser.toInt(), 256000)
assertEquals(deser.isInt, true)
assertEquals(deser.toInt(), 256000)
}

@Test
Expand All @@ -170,8 +158,8 @@ class EValueTest {
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isDouble, true)
Assert.assertEquals(1.345e-2, deser.toDouble(), 1e-6)
assertEquals(deser.isDouble, true)
assertEquals(1.345e-2, deser.toDouble(), 1e-6)
}

@Test
Expand All @@ -184,17 +172,17 @@ class EValueTest {
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isTensor, true)
assertEquals(deser.isTensor, true)
val deserTensor = deser.toTensor()
val deserShape = deserTensor.shape()
val deserData = deserTensor.dataAsLongArray

for (i in data.indices) {
Assert.assertEquals(data[i], deserData[i])
assertEquals(data[i], deserData[i])
}

for (i in shape.indices) {
Assert.assertEquals(shape[i], deserShape[i])
assertEquals(shape[i], deserShape[i])
}
}

Expand All @@ -208,17 +196,17 @@ class EValueTest {
val bytes = evalue.toByteArray()

val deser = EValue.fromByteArray(bytes)
Assert.assertEquals(deser.isTensor, true)
assertEquals(deser.isTensor, true)
val deserTensor = deser.toTensor()
val deserShape = deserTensor.shape()
val deserData = deserTensor.dataAsFloatArray

for (i in data.indices) {
Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5)
assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5)
}

for (i in shape.indices) {
Assert.assertEquals(shape[i], deserShape[i])
assertEquals(shape[i], deserShape[i])
}
}
}
Loading
Loading