Skip to content

Commit 6f8a73d

Browse files
committed
Update
[ghstack-poisoned]
1 parent 317e571 commit 6f8a73d

File tree

2 files changed

+208
-0
lines changed

2 files changed

+208
-0
lines changed

extension/android/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ task makeJar(type: Jar) {
2020
dependencies {
2121
implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2'
2222
implementation 'com.facebook.soloader:nativeloader:0.10.5'
23+
testImplementation 'junit:junit:4.13.2'
2324
}
2425
}
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import static org.junit.Assert.assertEquals;
12+
import static org.junit.Assert.assertTrue;
13+
import static org.junit.Assert.assertFalse;
14+
import static org.junit.Assert.assertNotEquals;
15+
16+
import com.facebook.jni.annotations.DoNotStrip;
17+
18+
import java.util.List;
19+
import java.util.ArrayList;
20+
import java.util.Arrays;
21+
import java.util.Locale;
22+
import java.util.Optional;
23+
24+
import org.pytorch.executorch.Tensor.Tensor_int64;
25+
import org.pytorch.executorch.annotations.Experimental;
26+
27+
import org.junit.Test;
28+
import org.junit.runner.RunWith;
29+
import org.junit.runners.JUnit4;
30+
31+
/** Unit tests for {@link EValue}. */
32+
@RunWith(JUnit4.class)
33+
public class EValueTest {
34+
35+
@Test
36+
public void testNone() {
37+
EValue evalue = EValue.optionalNone();
38+
assertTrue(evalue.isNone());
39+
}
40+
41+
@Test
42+
public void testTensorValue() {
43+
long[] data = {1, 2, 3};
44+
long[] shape = {1, 3};
45+
EValue evalue = EValue.from(Tensor.fromBlob(data, shape));
46+
assertTrue(evalue.isTensor());
47+
assertTrue(Arrays.equals(evalue.toTensor().shape, shape));
48+
assertTrue(Arrays.equals(evalue.toTensor().getDataAsLongArray(), data));
49+
}
50+
51+
@Test
52+
public void testBoolValue() {
53+
EValue evalue = EValue.from(true);
54+
assertTrue(evalue.isBool());
55+
assertTrue(evalue.toBool());
56+
}
57+
58+
@Test
59+
public void testIntValue() {
60+
EValue evalue = EValue.from(1);
61+
assertTrue(evalue.isInt());
62+
assertEquals(evalue.toInt(), 1);
63+
}
64+
65+
@Test
66+
public void testDoubleValue() {
67+
EValue evalue = EValue.from(0.1d);
68+
assertTrue(evalue.isDouble());
69+
assertEquals(evalue.toDouble(), 0.1d, 0.0001d);
70+
}
71+
72+
@Test
73+
public void testStringValue() {
74+
EValue evalue = EValue.from("a");
75+
assertTrue(evalue.isString());
76+
assertEquals(evalue.toStr(), "a");
77+
}
78+
79+
@Test
80+
public void testBoolListValue() {
81+
boolean[] value = {true, false, true};
82+
EValue evalue = EValue.listFrom(value);
83+
assertTrue(evalue.isBoolList());
84+
assertTrue(Arrays.equals(value, evalue.toBoolList()));
85+
}
86+
87+
@Test
88+
public void testIntListValue() {
89+
long[] value = {Long.MIN_VALUE, 0, Long.MAX_VALUE};
90+
EValue evalue = EValue.listFrom(value);
91+
assertTrue(evalue.isIntList());
92+
assertTrue(Arrays.equals(value, evalue.toIntList()));
93+
}
94+
95+
@Test
96+
public void testDoubleListValue() {
97+
double[] value = {Double.MIN_VALUE,0.1d, 0.01d, 0.001d, Double.MAX_VALUE};
98+
EValue evalue = EValue.listFrom(value);
99+
assertTrue(evalue.isDoubleList());
100+
assertTrue(Arrays.equals(value, evalue.toDoubleList()));
101+
}
102+
103+
@Test
104+
public void testTensorListValue() {
105+
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
106+
long[][] shape = {{1, 3}, {2, 3}};
107+
Tensor[] tensors = {Tensor.fromBlob(data[0], shape[0]), Tensor.fromBlob(data[1], shape[1])};
108+
109+
EValue evalue = EValue.listFrom(tensors);
110+
assertTrue(evalue.isTensorList());
111+
112+
assertTrue(Arrays.equals(evalue.toTensorList()[0].shape, shape[0]));
113+
assertTrue(Arrays.equals(evalue.toTensorList()[0].getDataAsLongArray(), data[0]));
114+
115+
assertTrue(Arrays.equals(evalue.toTensorList()[1].shape, shape[1]));
116+
assertTrue(Arrays.equals(evalue.toTensorList()[1].getDataAsLongArray(), data[1]));
117+
}
118+
119+
@Test
120+
@SuppressWarnings("unchecked")
121+
public void testOptionalTensorListValue() {
122+
long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}};
123+
long[][] shape = {{1, 3}, {2, 3}};
124+
125+
EValue evalue = EValue.listFrom(
126+
Optional.<Tensor>empty(),
127+
Optional.of(Tensor.fromBlob(data[0], shape[0])),
128+
Optional.of(Tensor.fromBlob(data[1], shape[1])));
129+
assertTrue(evalue.isOptionalTensorList());
130+
131+
assertTrue(evalue.toOptionalTensorList()[0].isEmpty());
132+
133+
assertTrue(evalue.toOptionalTensorList()[1].isPresent());
134+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0]));
135+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().getDataAsLongArray(), data[0]));
136+
137+
assertTrue(evalue.toOptionalTensorList()[2].isPresent());
138+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().shape, shape[1]));
139+
assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().getDataAsLongArray(), data[1]));
140+
}
141+
142+
@Test
143+
public void testAllIllegalCast() {
144+
EValue evalue = EValue.optionalNone();
145+
assertTrue(evalue.isNone());
146+
147+
// try Tensor
148+
assertFalse(evalue.isTensor());
149+
try {
150+
evalue.toTensor();
151+
} catch (IllegalStateException e) {}
152+
153+
// try bool
154+
assertFalse(evalue.isBool());
155+
try {
156+
evalue.toBool();
157+
} catch (IllegalStateException e) {}
158+
159+
// try int
160+
assertFalse(evalue.isInt());
161+
try {
162+
evalue.toInt();
163+
} catch (IllegalStateException e) {}
164+
165+
// try double
166+
assertFalse(evalue.isDouble());
167+
try {
168+
evalue.toDouble();
169+
} catch (IllegalStateException e) {}
170+
171+
// try string
172+
assertFalse(evalue.isString());
173+
try {
174+
evalue.toStr();
175+
} catch (IllegalStateException e) {}
176+
177+
// try bool list
178+
assertFalse(evalue.isBoolList());
179+
try {
180+
evalue.toBoolList();
181+
} catch (IllegalStateException e) {}
182+
183+
// try int list
184+
assertFalse(evalue.isIntList());
185+
try {
186+
evalue.toIntList();
187+
} catch (IllegalStateException e) {}
188+
189+
// try double list
190+
assertFalse(evalue.isDoubleList());
191+
try {
192+
evalue.toBool();
193+
} catch (IllegalStateException e) {}
194+
195+
// try Tensor list
196+
assertFalse(evalue.isTensorList());
197+
try {
198+
evalue.toTensorList();
199+
} catch (IllegalStateException e) {}
200+
201+
// try optional Tensor list
202+
assertFalse(evalue.isOptionalTensorList());
203+
try {
204+
evalue.toOptionalTensorList();
205+
} catch (IllegalStateException e) {}
206+
}
207+
}

0 commit comments

Comments
 (0)