77 */
88package org.pytorch.executorch
99
10- import org.junit.Assert
10+ import org.assertj.core.api.Assertions.assertThatThrownBy
11+ import org.junit.Assert.assertEquals
12+ import org.junit.Assert.assertFalse
13+ import org.junit.Assert.assertTrue
1114import org.junit.Test
1215import org.junit.runner.RunWith
1316import org.junit.runners.JUnit4
@@ -18,91 +21,76 @@ class EValueTest {
1821 @Test
1922 fun testNone () {
2023 val evalue = EValue .optionalNone()
21- Assert . assertTrue(evalue.isNone)
24+ assertTrue(evalue.isNone)
2225 }
2326
2427 @Test
2528 fun testTensorValue () {
2629 val data = longArrayOf(1 , 2 , 3 )
2730 val shape = longArrayOf(1 , 3 )
2831 val evalue = EValue .from(Tensor .fromBlob(data, shape))
29- Assert . assertTrue(evalue.isTensor)
30- Assert . assertTrue(evalue.toTensor().shape.contentEquals(shape))
31- Assert . assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
32+ assertTrue(evalue.isTensor)
33+ assertTrue(evalue.toTensor().shape.contentEquals(shape))
34+ assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data))
3235 }
3336
3437 @Test
3538 fun testBoolValue () {
3639 val evalue = EValue .from(true )
37- Assert . assertTrue(evalue.isBool)
38- Assert . assertTrue(evalue.toBool())
40+ assertTrue(evalue.isBool)
41+ assertTrue(evalue.toBool())
3942 }
4043
4144 @Test
4245 fun testIntValue () {
4346 val evalue = EValue .from(1 )
44- Assert . assertTrue(evalue.isInt)
45- Assert . assertEquals(evalue.toInt(), 1 )
47+ assertTrue(evalue.isInt)
48+ assertEquals(evalue.toInt(), 1 )
4649 }
4750
4851 @Test
4952 fun testDoubleValue () {
5053 val evalue = EValue .from(0.1 )
51- Assert . assertTrue(evalue.isDouble)
52- Assert . assertEquals(evalue.toDouble(), 0.1 , 0.0001 )
54+ assertTrue(evalue.isDouble)
55+ assertEquals(evalue.toDouble(), 0.1 , 0.0001 )
5356 }
5457
5558 @Test
5659 fun testStringValue () {
5760 val evalue = EValue .from(" a" )
58- Assert . assertTrue(evalue.isString)
59- Assert . assertEquals(evalue.toStr(), " a" )
61+ assertTrue(evalue.isString)
62+ assertEquals(evalue.toStr(), " a" )
6063 }
6164
6265 @Test
6366 fun testAllIllegalCast () {
6467 val evalue = EValue .optionalNone()
65- Assert . assertTrue(evalue.isNone)
68+ assertTrue(evalue.isNone)
6669
6770 // try Tensor
68- Assert .assertFalse(evalue.isTensor)
69- try {
70- evalue.toTensor()
71- Assert .fail(" Should have thrown an exception" )
72- } catch (e: IllegalStateException ) {
73- }
71+ assertFalse(evalue.isTensor)
72+ assertThatThrownBy {
73+ evalue.toTensor() }.isInstanceOf(IllegalStateException ::class .java).hasMessage(" Expected EValue type Tensor, actual type None" )
7474
7575 // try bool
76- Assert .assertFalse(evalue.isBool)
77- try {
78- evalue.toBool()
79- Assert .fail(" Should have thrown an exception" )
80- } catch (e: IllegalStateException ) {
81- }
76+ assertFalse(evalue.isBool)
77+ assertThatThrownBy {
78+ evalue.toBool() }.isInstanceOf(IllegalStateException ::class .java).hasMessage(" Expected EValue type Bool, actual type None" )
8279
8380 // try int
84- Assert .assertFalse(evalue.isInt)
85- try {
86- evalue.toInt()
87- Assert .fail(" Should have thrown an exception" )
88- } catch (e: IllegalStateException ) {
89- }
81+ assertFalse(evalue.isInt)
82+ assertThatThrownBy {
83+ evalue.toInt() }.isInstanceOf(IllegalStateException ::class .java).hasMessage(" Expected EValue type Int, actual type None" )
9084
9185 // try double
92- Assert .assertFalse(evalue.isDouble)
93- try {
94- evalue.toDouble()
95- Assert .fail(" Should have thrown an exception" )
96- } catch (e: IllegalStateException ) {
97- }
86+ assertFalse(evalue.isDouble)
87+ assertThatThrownBy {
88+ evalue.toDouble() }.isInstanceOf(IllegalStateException ::class .java).hasMessage(" Expected EValue type Double, actual type None" )
9889
9990 // try string
100- Assert .assertFalse(evalue.isString)
101- try {
102- evalue.toStr()
103- Assert .fail(" Should have thrown an exception" )
104- } catch (e: IllegalStateException ) {
105- }
91+ assertFalse(evalue.isString)
92+ assertThatThrownBy {
93+ evalue.toStr() }.isInstanceOf(IllegalStateException ::class .java).hasMessage(" Expected EValue type String, actual type None" )
10694 }
10795
10896 @Test
@@ -111,47 +99,47 @@ class EValueTest {
11199 val bytes = evalue.toByteArray()
112100
113101 val deser = EValue .fromByteArray(bytes)
114- Assert . assertEquals(deser.isNone, true )
102+ assertEquals(deser.isNone, true )
115103 }
116104
117105 @Test
118106 fun testBoolSerde () {
119107 val evalue = EValue .from(true )
120108 val bytes = evalue.toByteArray()
121- Assert . assertEquals(1 , bytes[1 ].toLong())
109+ assertEquals(1 , bytes[1 ].toLong())
122110
123111 val deser = EValue .fromByteArray(bytes)
124- Assert . assertEquals(deser.isBool, true )
125- Assert . assertEquals(deser.toBool(), true )
112+ assertEquals(deser.isBool, true )
113+ assertEquals(deser.toBool(), true )
126114 }
127115
128116 @Test
129117 fun testBoolSerde2 () {
130118 val evalue = EValue .from(false )
131119 val bytes = evalue.toByteArray()
132- Assert . assertEquals(0 , bytes[1 ].toLong())
120+ assertEquals(0 , bytes[1 ].toLong())
133121
134122 val deser = EValue .fromByteArray(bytes)
135- Assert . assertEquals(deser.isBool, true )
136- Assert . assertEquals(deser.toBool(), false )
123+ assertEquals(deser.isBool, true )
124+ assertEquals(deser.toBool(), false )
137125 }
138126
139127 @Test
140128 fun testIntSerde () {
141129 val evalue = EValue .from(1 )
142130 val bytes = evalue.toByteArray()
143- Assert . assertEquals(0 , bytes[1 ].toLong())
144- Assert . assertEquals(0 , bytes[2 ].toLong())
145- Assert . assertEquals(0 , bytes[3 ].toLong())
146- Assert . assertEquals(0 , bytes[4 ].toLong())
147- Assert . assertEquals(0 , bytes[5 ].toLong())
148- Assert . assertEquals(0 , bytes[6 ].toLong())
149- Assert . assertEquals(0 , bytes[7 ].toLong())
150- Assert . assertEquals(1 , bytes[8 ].toLong())
131+ assertEquals(0 , bytes[1 ].toLong())
132+ assertEquals(0 , bytes[2 ].toLong())
133+ assertEquals(0 , bytes[3 ].toLong())
134+ assertEquals(0 , bytes[4 ].toLong())
135+ assertEquals(0 , bytes[5 ].toLong())
136+ assertEquals(0 , bytes[6 ].toLong())
137+ assertEquals(0 , bytes[7 ].toLong())
138+ assertEquals(1 , bytes[8 ].toLong())
151139
152140 val deser = EValue .fromByteArray(bytes)
153- Assert . assertEquals(deser.isInt, true )
154- Assert . assertEquals(deser.toInt(), 1 )
141+ assertEquals(deser.isInt, true )
142+ assertEquals(deser.toInt(), 1 )
155143 }
156144
157145 @Test
@@ -160,8 +148,8 @@ class EValueTest {
160148 val bytes = evalue.toByteArray()
161149
162150 val deser = EValue .fromByteArray(bytes)
163- Assert . assertEquals(deser.isInt, true )
164- Assert . assertEquals(deser.toInt(), 256000 )
151+ assertEquals(deser.isInt, true )
152+ assertEquals(deser.toInt(), 256000 )
165153 }
166154
167155 @Test
@@ -170,8 +158,8 @@ class EValueTest {
170158 val bytes = evalue.toByteArray()
171159
172160 val deser = EValue .fromByteArray(bytes)
173- Assert . assertEquals(deser.isDouble, true )
174- Assert . assertEquals(1 .345e- 2 , deser.toDouble(), 1e- 6 )
161+ assertEquals(deser.isDouble, true )
162+ assertEquals(1 .345e- 2 , deser.toDouble(), 1e- 6 )
175163 }
176164
177165 @Test
@@ -184,17 +172,17 @@ class EValueTest {
184172 val bytes = evalue.toByteArray()
185173
186174 val deser = EValue .fromByteArray(bytes)
187- Assert . assertEquals(deser.isTensor, true )
175+ assertEquals(deser.isTensor, true )
188176 val deserTensor = deser.toTensor()
189177 val deserShape = deserTensor.shape()
190178 val deserData = deserTensor.dataAsLongArray
191179
192180 for (i in data.indices) {
193- Assert . assertEquals(data[i], deserData[i])
181+ assertEquals(data[i], deserData[i])
194182 }
195183
196184 for (i in shape.indices) {
197- Assert . assertEquals(shape[i], deserShape[i])
185+ assertEquals(shape[i], deserShape[i])
198186 }
199187 }
200188
@@ -208,17 +196,17 @@ class EValueTest {
208196 val bytes = evalue.toByteArray()
209197
210198 val deser = EValue .fromByteArray(bytes)
211- Assert . assertEquals(deser.isTensor, true )
199+ assertEquals(deser.isTensor, true )
212200 val deserTensor = deser.toTensor()
213201 val deserShape = deserTensor.shape()
214202 val deserData = deserTensor.dataAsFloatArray
215203
216204 for (i in data.indices) {
217- Assert . assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e- 5 )
205+ assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e- 5 )
218206 }
219207
220208 for (i in shape.indices) {
221- Assert . assertEquals(shape[i], deserShape[i])
209+ assertEquals(shape[i], deserShape[i])
222210 }
223211 }
224212}
0 commit comments