Skip to content

Commit 6dec17b

Browse files
author
Haiting Pu
committed
Introduce assertj test lib to make the throw exception test more accurate
1 parent b1b46ee commit 6dec17b

File tree

4 files changed

+195
-234
lines changed

4 files changed

+195
-234
lines changed

extension/android/executorch_android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ dependencies {
4949
implementation 'com.facebook.soloader:nativeloader:0.10.5'
5050
implementation libs.core.ktx
5151
testImplementation 'junit:junit:4.12'
52+
testImplementation 'org.assertj:assertj-core:3.27.2'
5253
androidTestImplementation 'androidx.test.ext:junit:1.1.5'
5354
androidTestImplementation 'androidx.test:rules:1.2.0'
5455
androidTestImplementation 'commons-io:commons-io:2.4'

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ public byte[] getDataAsByteArray() {
394394
*/
395395
public byte[] getDataAsUnsignedByteArray() {
396396
throw new IllegalStateException(
397-
"Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array.");
397+
"Tensor of type " + getClass().getSimpleName() + " cannot return data as unsigned byte array.");
398398
}
399399

400400
/**

extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt

Lines changed: 59 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
*/
88
package org.pytorch.executorch
99

10-
import org.junit.Assert
10+
import org.assertj.core.api.Assertions.assertThatThrownBy
11+
import org.junit.Assert.assertEquals
12+
import org.junit.Assert.assertFalse
13+
import org.junit.Assert.assertTrue
1114
import org.junit.Test
1215
import org.junit.runner.RunWith
1316
import org.junit.runners.JUnit4
@@ -18,91 +21,76 @@ class EValueTest {
1821
@Test
1922
fun testNone() {
2023
val evalue = EValue.optionalNone()
21-
Assert.assertTrue(evalue.isNone)
24+
assertTrue(evalue.isNone)
2225
}
2326

2427
@Test
2528
fun testTensorValue() {
2629
val data = longArrayOf(1, 2, 3)
2730
val shape = longArrayOf(1, 3)
2831
val evalue = EValue.from(Tensor.fromBlob(data, shape))
29-
Assert.assertTrue(evalue.isTensor)
30-
Assert.assertTrue(evalue.toTensor().shape.contentEquals(shape))
31-
Assert.assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
32+
assertTrue(evalue.isTensor)
33+
assertTrue(evalue.toTensor().shape.contentEquals(shape))
34+
assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
3235
}
3336

3437
@Test
3538
fun testBoolValue() {
3639
val evalue = EValue.from(true)
37-
Assert.assertTrue(evalue.isBool)
38-
Assert.assertTrue(evalue.toBool())
40+
assertTrue(evalue.isBool)
41+
assertTrue(evalue.toBool())
3942
}
4043

4144
@Test
4245
fun testIntValue() {
4346
val evalue = EValue.from(1)
44-
Assert.assertTrue(evalue.isInt)
45-
Assert.assertEquals(evalue.toInt(), 1)
47+
assertTrue(evalue.isInt)
48+
assertEquals(evalue.toInt(), 1)
4649
}
4750

4851
@Test
4952
fun testDoubleValue() {
5053
val evalue = EValue.from(0.1)
51-
Assert.assertTrue(evalue.isDouble)
52-
Assert.assertEquals(evalue.toDouble(), 0.1, 0.0001)
54+
assertTrue(evalue.isDouble)
55+
assertEquals(evalue.toDouble(), 0.1, 0.0001)
5356
}
5457

5558
@Test
5659
fun testStringValue() {
5760
val evalue = EValue.from("a")
58-
Assert.assertTrue(evalue.isString)
59-
Assert.assertEquals(evalue.toStr(), "a")
61+
assertTrue(evalue.isString)
62+
assertEquals(evalue.toStr(), "a")
6063
}
6164

6265
@Test
6366
fun testAllIllegalCast() {
6467
val evalue = EValue.optionalNone()
65-
Assert.assertTrue(evalue.isNone)
68+
assertTrue(evalue.isNone)
6669

6770
// try Tensor
68-
Assert.assertFalse(evalue.isTensor)
69-
try {
70-
evalue.toTensor()
71-
Assert.fail("Should have thrown an exception")
72-
} catch (e: IllegalStateException) {
73-
}
71+
assertFalse(evalue.isTensor)
72+
assertThatThrownBy {
73+
evalue.toTensor() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Tensor, actual type None")
7474

7575
// try bool
76-
Assert.assertFalse(evalue.isBool)
77-
try {
78-
evalue.toBool()
79-
Assert.fail("Should have thrown an exception")
80-
} catch (e: IllegalStateException) {
81-
}
76+
assertFalse(evalue.isBool)
77+
assertThatThrownBy {
78+
evalue.toBool() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Bool, actual type None")
8279

8380
// try int
84-
Assert.assertFalse(evalue.isInt)
85-
try {
86-
evalue.toInt()
87-
Assert.fail("Should have thrown an exception")
88-
} catch (e: IllegalStateException) {
89-
}
81+
assertFalse(evalue.isInt)
82+
assertThatThrownBy {
83+
evalue.toInt() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Int, actual type None")
9084

9185
// try double
92-
Assert.assertFalse(evalue.isDouble)
93-
try {
94-
evalue.toDouble()
95-
Assert.fail("Should have thrown an exception")
96-
} catch (e: IllegalStateException) {
97-
}
86+
assertFalse(evalue.isDouble)
87+
assertThatThrownBy {
88+
evalue.toDouble() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Double, actual type None")
9889

9990
// try string
100-
Assert.assertFalse(evalue.isString)
101-
try {
102-
evalue.toStr()
103-
Assert.fail("Should have thrown an exception")
104-
} catch (e: IllegalStateException) {
105-
}
91+
assertFalse(evalue.isString)
92+
assertThatThrownBy {
93+
evalue.toStr() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type String, actual type None")
10694
}
10795

10896
@Test
@@ -111,47 +99,47 @@ class EValueTest {
11199
val bytes = evalue.toByteArray()
112100

113101
val deser = EValue.fromByteArray(bytes)
114-
Assert.assertEquals(deser.isNone, true)
102+
assertEquals(deser.isNone, true)
115103
}
116104

117105
@Test
118106
fun testBoolSerde() {
119107
val evalue = EValue.from(true)
120108
val bytes = evalue.toByteArray()
121-
Assert.assertEquals(1, bytes[1].toLong())
109+
assertEquals(1, bytes[1].toLong())
122110

123111
val deser = EValue.fromByteArray(bytes)
124-
Assert.assertEquals(deser.isBool, true)
125-
Assert.assertEquals(deser.toBool(), true)
112+
assertEquals(deser.isBool, true)
113+
assertEquals(deser.toBool(), true)
126114
}
127115

128116
@Test
129117
fun testBoolSerde2() {
130118
val evalue = EValue.from(false)
131119
val bytes = evalue.toByteArray()
132-
Assert.assertEquals(0, bytes[1].toLong())
120+
assertEquals(0, bytes[1].toLong())
133121

134122
val deser = EValue.fromByteArray(bytes)
135-
Assert.assertEquals(deser.isBool, true)
136-
Assert.assertEquals(deser.toBool(), false)
123+
assertEquals(deser.isBool, true)
124+
assertEquals(deser.toBool(), false)
137125
}
138126

139127
@Test
140128
fun testIntSerde() {
141129
val evalue = EValue.from(1)
142130
val bytes = evalue.toByteArray()
143-
Assert.assertEquals(0, bytes[1].toLong())
144-
Assert.assertEquals(0, bytes[2].toLong())
145-
Assert.assertEquals(0, bytes[3].toLong())
146-
Assert.assertEquals(0, bytes[4].toLong())
147-
Assert.assertEquals(0, bytes[5].toLong())
148-
Assert.assertEquals(0, bytes[6].toLong())
149-
Assert.assertEquals(0, bytes[7].toLong())
150-
Assert.assertEquals(1, bytes[8].toLong())
131+
assertEquals(0, bytes[1].toLong())
132+
assertEquals(0, bytes[2].toLong())
133+
assertEquals(0, bytes[3].toLong())
134+
assertEquals(0, bytes[4].toLong())
135+
assertEquals(0, bytes[5].toLong())
136+
assertEquals(0, bytes[6].toLong())
137+
assertEquals(0, bytes[7].toLong())
138+
assertEquals(1, bytes[8].toLong())
151139

152140
val deser = EValue.fromByteArray(bytes)
153-
Assert.assertEquals(deser.isInt, true)
154-
Assert.assertEquals(deser.toInt(), 1)
141+
assertEquals(deser.isInt, true)
142+
assertEquals(deser.toInt(), 1)
155143
}
156144

157145
@Test
@@ -160,8 +148,8 @@ class EValueTest {
160148
val bytes = evalue.toByteArray()
161149

162150
val deser = EValue.fromByteArray(bytes)
163-
Assert.assertEquals(deser.isInt, true)
164-
Assert.assertEquals(deser.toInt(), 256000)
151+
assertEquals(deser.isInt, true)
152+
assertEquals(deser.toInt(), 256000)
165153
}
166154

167155
@Test
@@ -170,8 +158,8 @@ class EValueTest {
170158
val bytes = evalue.toByteArray()
171159

172160
val deser = EValue.fromByteArray(bytes)
173-
Assert.assertEquals(deser.isDouble, true)
174-
Assert.assertEquals(1.345e-2, deser.toDouble(), 1e-6)
161+
assertEquals(deser.isDouble, true)
162+
assertEquals(1.345e-2, deser.toDouble(), 1e-6)
175163
}
176164

177165
@Test
@@ -184,17 +172,17 @@ class EValueTest {
184172
val bytes = evalue.toByteArray()
185173

186174
val deser = EValue.fromByteArray(bytes)
187-
Assert.assertEquals(deser.isTensor, true)
175+
assertEquals(deser.isTensor, true)
188176
val deserTensor = deser.toTensor()
189177
val deserShape = deserTensor.shape()
190178
val deserData = deserTensor.dataAsLongArray
191179

192180
for (i in data.indices) {
193-
Assert.assertEquals(data[i], deserData[i])
181+
assertEquals(data[i], deserData[i])
194182
}
195183

196184
for (i in shape.indices) {
197-
Assert.assertEquals(shape[i], deserShape[i])
185+
assertEquals(shape[i], deserShape[i])
198186
}
199187
}
200188

@@ -208,17 +196,17 @@ class EValueTest {
208196
val bytes = evalue.toByteArray()
209197

210198
val deser = EValue.fromByteArray(bytes)
211-
Assert.assertEquals(deser.isTensor, true)
199+
assertEquals(deser.isTensor, true)
212200
val deserTensor = deser.toTensor()
213201
val deserShape = deserTensor.shape()
214202
val deserData = deserTensor.dataAsFloatArray
215203

216204
for (i in data.indices) {
217-
Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5)
205+
assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5)
218206
}
219207

220208
for (i in shape.indices) {
221-
Assert.assertEquals(shape[i], deserShape[i])
209+
assertEquals(shape[i], deserShape[i])
222210
}
223211
}
224212
}

0 commit comments

Comments
 (0)