diff --git a/extension/android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/src/main/java/org/pytorch/executorch/EValue.java index 016b6a3e097..f133eb4ad60 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/src/main/java/org/pytorch/executorch/EValue.java @@ -9,6 +9,8 @@ package org.pytorch.executorch; import com.facebook.jni.annotations.DoNotStrip; +import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Locale; import java.util.Optional; import org.pytorch.executorch.annotations.Experimental; @@ -287,4 +289,75 @@ private void preconditionType(int typeCodeExpected, int typeCode) { private String getTypeName(int typeCode) { return typeCode >= 0 && typeCode < TYPE_NAMES.length ? TYPE_NAMES[typeCode] : "Unknown"; } + + /** + * Serializes an {@code EValue} into a byte array. + * + * @return The serialized byte array. + * @apiNote This method is experimental and subject to change without notice. This does NOT + * supoprt list type. + */ + public byte[] toByteArray() { + if (isNone()) { + return ByteBuffer.allocate(1).put((byte) TYPE_CODE_NONE).array(); + } else if (isTensor()) { + Tensor t = toTensor(); + byte[] tByteArray = t.toByteArray(); + return ByteBuffer.allocate(1 + tByteArray.length) + .put((byte) TYPE_CODE_TENSOR) + .put(tByteArray) + .array(); + } else if (isBool()) { + return ByteBuffer.allocate(2) + .put((byte) TYPE_CODE_BOOL) + .put((byte) (toBool() ? 1 : 0)) + .array(); + } else if (isInt()) { + return ByteBuffer.allocate(9).put((byte) TYPE_CODE_INT).putLong(toInt()).array(); + } else if (isDouble()) { + return ByteBuffer.allocate(9).put((byte) TYPE_CODE_DOUBLE).putDouble(toDouble()).array(); + } else if (isString()) { + return ByteBuffer.allocate(1 + toString().length()) + .put((byte) TYPE_CODE_STRING) + .put(toString().getBytes()) + .array(); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + } + + /** + * Deserializes an {@code EValue} from a byte[]. + * + * @param bytes The byte array to deserialize from. + * @return The deserialized {@code EValue}. + * @apiNote This method is experimental and subject to change without notice. This does NOT list + * type. + */ + public static EValue fromByteArray(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.wrap(bytes); + if (buffer == null) { + throw new IllegalArgumentException("buffer cannot be null"); + } + if (!buffer.hasRemaining()) { + throw new IllegalArgumentException("invalid buffer"); + } + int typeCode = buffer.get(); + switch (typeCode) { + case TYPE_CODE_NONE: + return new EValue(TYPE_CODE_NONE); + case TYPE_CODE_TENSOR: + byte[] bufferArray = buffer.array(); + return from(Tensor.fromByteArray(Arrays.copyOfRange(bufferArray, 1, bufferArray.length))); + case TYPE_CODE_STRING: + throw new IllegalArgumentException("TYPE_CODE_STRING is not supported"); + case TYPE_CODE_DOUBLE: + return from(buffer.getDouble()); + case TYPE_CODE_INT: + return from(buffer.getLong()); + case TYPE_CODE_BOOL: + return from(buffer.get() != 0); + } + throw new IllegalArgumentException("invalid type code: " + typeCode); + } } diff --git a/extension/android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/src/main/java/org/pytorch/executorch/Tensor.java index 685110ff9ae..f76a247a59a 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/src/main/java/org/pytorch/executorch/Tensor.java @@ -53,9 +53,10 @@ public abstract class Tensor { @DoNotStrip final long[] shape; + private static final int BYTE_SIZE_BYTES = 1; private static final int INT_SIZE_BYTES = 4; - private static final int FLOAT_SIZE_BYTES = 4; private static final int LONG_SIZE_BYTES = 8; + private static final int FLOAT_SIZE_BYTES = 4; private static final int DOUBLE_SIZE_BYTES = 8; /** @@ -679,4 +680,105 @@ private static Tensor nativeNewTensor( tensor.mHybridData = hybridData; return tensor; } + + /** + * Serializes a {@code Tensor} into a byte array. + * + * @return The serialized byte array. + * @apiNote This method is experimental and subject to change without notice. This does NOT + * supoprt list type. + */ + public byte[] toByteArray() { + int dtypeSize = 0; + byte[] tensorAsByteArray = null; + if (dtype() == DType.UINT8) { + dtypeSize = BYTE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel()]; + Tensor_uint8 thiz = (Tensor_uint8) this; + ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsUnsignedByteArray()); + } else if (dtype() == DType.INT8) { + dtypeSize = BYTE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel()]; + Tensor_int8 thiz = (Tensor_int8) this; + ByteBuffer.wrap(tensorAsByteArray).put(thiz.getDataAsByteArray()); + } else if (dtype() == DType.INT16) { + throw new IllegalArgumentException("DType.INT16 is not supported in Java so far"); + } else if (dtype() == DType.INT32) { + dtypeSize = INT_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_int32 thiz = (Tensor_int32) this; + ByteBuffer.wrap(tensorAsByteArray).asIntBuffer().put(thiz.getDataAsIntArray()); + } else if (dtype() == DType.INT64) { + dtypeSize = LONG_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_int64 thiz = (Tensor_int64) this; + ByteBuffer.wrap(tensorAsByteArray).asLongBuffer().put(thiz.getDataAsLongArray()); + } else if (dtype() == DType.FLOAT) { + dtypeSize = FLOAT_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_float32 thiz = (Tensor_float32) this; + ByteBuffer.wrap(tensorAsByteArray).asFloatBuffer().put(thiz.getDataAsFloatArray()); + } else if (dtype() == DType.DOUBLE) { + dtypeSize = DOUBLE_SIZE_BYTES; + tensorAsByteArray = new byte[(int) numel() * dtypeSize]; + Tensor_float64 thiz = (Tensor_float64) this; + ByteBuffer.wrap(tensorAsByteArray).asDoubleBuffer().put(thiz.getDataAsDoubleArray()); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + ByteBuffer byteBuffer = + ByteBuffer.allocate(1 + 1 + 4 * shape.length + dtypeSize * (int) numel()); + byteBuffer.put((byte) dtype().jniCode); + byteBuffer.put((byte) shape.length); + for (long s : shape) { + byteBuffer.putInt((int) s); + } + byteBuffer.put(tensorAsByteArray); + return byteBuffer.array(); + } + + /** + * Deserializes a {@code Tensor} from a byte[]. + * + * @param buffer The byte array to deserialize from. + * @return The deserialized {@code Tensor}. + * @apiNote This method is experimental and subject to change without notice. This does NOT + * supoprt list type. + */ + public static Tensor fromByteArray(byte[] bytes) { + if (bytes == null) { + throw new IllegalArgumentException("bytes cannot be null"); + } + ByteBuffer buffer = ByteBuffer.wrap(bytes); + if (!buffer.hasRemaining()) { + throw new IllegalArgumentException("invalid buffer"); + } + byte dtype = buffer.get(); + byte shapeLength = buffer.get(); + long[] shape = new long[(int) shapeLength]; + long numel = 1; + for (int i = 0; i < shapeLength; i++) { + int dim = buffer.getInt(); + if (dim < 0) { + throw new IllegalArgumentException("invalid shape"); + } + shape[i] = dim; + numel *= dim; + } + if (dtype == DType.UINT8.jniCode) { + return new Tensor_uint8(buffer, shape); + } else if (dtype == DType.INT8.jniCode) { + return new Tensor_int8(buffer, shape); + } else if (dtype == DType.INT32.jniCode) { + return new Tensor_int32(buffer.asIntBuffer(), shape); + } else if (dtype == DType.INT64.jniCode) { + return new Tensor_int64(buffer.asLongBuffer(), shape); + } else if (dtype == DType.FLOAT.jniCode) { + return new Tensor_float32(buffer.asFloatBuffer(), shape); + } else if (dtype == DType.DOUBLE.jniCode) { + return new Tensor_float64(buffer.asDoubleBuffer(), shape); + } else { + throw new IllegalArgumentException("Unknown Tensor dtype"); + } + } } diff --git a/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java index 29cabae75fa..9856329da78 100644 --- a/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/EValueTest.java @@ -9,22 +9,12 @@ package org.pytorch.executorch; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import com.facebook.jni.annotations.DoNotStrip; - -import java.util.List; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Locale; import java.util.Optional; - -import org.pytorch.executorch.Tensor.Tensor_int64; -import org.pytorch.executorch.annotations.Experimental; - import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -32,7 +22,6 @@ /** Unit tests for {@link EValue}. */ @RunWith(JUnit4.class) public class EValueTest { - @Test public void testNone() { EValue evalue = EValue.optionalNone(); @@ -95,7 +84,7 @@ public void testIntListValue() { @Test public void testDoubleListValue() { - double[] value = {Double.MIN_VALUE,0.1d, 0.01d, 0.001d, Double.MAX_VALUE}; + double[] value = {Double.MIN_VALUE, 0.1d, 0.01d, 0.001d, Double.MAX_VALUE}; EValue evalue = EValue.listFrom(value); assertTrue(evalue.isDoubleList()); assertTrue(Arrays.equals(value, evalue.toDoubleList())); @@ -123,10 +112,11 @@ public void testOptionalTensorListValue() { long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}}; long[][] shape = {{1, 3}, {2, 3}}; - EValue evalue = EValue.listFrom( - Optional.empty(), - Optional.of(Tensor.fromBlob(data[0], shape[0])), - Optional.of(Tensor.fromBlob(data[1], shape[1]))); + EValue evalue = + EValue.listFrom( + Optional.empty(), + Optional.of(Tensor.fromBlob(data[0], shape[0])), + Optional.of(Tensor.fromBlob(data[1], shape[1]))); assertTrue(evalue.isOptionalTensorList()); assertTrue(!evalue.toOptionalTensorList()[0].isPresent()); @@ -144,75 +134,202 @@ public void testOptionalTensorListValue() { public void testAllIllegalCast() { EValue evalue = EValue.optionalNone(); assertTrue(evalue.isNone()); - + // try Tensor assertFalse(evalue.isTensor()); try { - evalue.toTensor(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toTensor(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try bool assertFalse(evalue.isBool()); try { - evalue.toBool(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toBool(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try int assertFalse(evalue.isInt()); try { - evalue.toInt(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toInt(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try double assertFalse(evalue.isDouble()); try { - evalue.toDouble(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toDouble(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try string assertFalse(evalue.isString()); try { - evalue.toStr(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toStr(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try bool list assertFalse(evalue.isBoolList()); try { - evalue.toBoolList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toBoolList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try int list assertFalse(evalue.isIntList()); try { - evalue.toIntList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toIntList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try double list assertFalse(evalue.isDoubleList()); try { - evalue.toBool(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toBool(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try Tensor list assertFalse(evalue.isTensorList()); try { - evalue.toTensorList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} - + evalue.toTensorList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + // try optional Tensor list assertFalse(evalue.isOptionalTensorList()); try { - evalue.toOptionalTensorList(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) {} + evalue.toOptionalTensorList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) { + } + } + + @Test + public void testNoneSerde() { + EValue evalue = EValue.optionalNone(); + byte[] bytes = evalue.toByteArray(); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isNone(), true); + } + + @Test + public void testBoolSerde() { + EValue evalue = EValue.from(true); + byte[] bytes = evalue.toByteArray(); + assertEquals(1, bytes[1]); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isBool(), true); + assertEquals(deser.toBool(), true); + } + + @Test + public void testBoolSerde2() { + EValue evalue = EValue.from(false); + byte[] bytes = evalue.toByteArray(); + assertEquals(0, bytes[1]); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isBool(), true); + assertEquals(deser.toBool(), false); + } + + @Test + public void testIntSerde() { + EValue evalue = EValue.from(1); + byte[] bytes = evalue.toByteArray(); + assertEquals(0, bytes[1]); + assertEquals(0, bytes[2]); + assertEquals(0, bytes[3]); + assertEquals(0, bytes[4]); + assertEquals(0, bytes[5]); + assertEquals(0, bytes[6]); + assertEquals(0, bytes[7]); + assertEquals(1, bytes[8]); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isInt(), true); + assertEquals(deser.toInt(), 1); + } + + @Test + public void testLargeIntSerde() { + EValue evalue = EValue.from(256000); + byte[] bytes = evalue.toByteArray(); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isInt(), true); + assertEquals(deser.toInt(), 256000); + } + + @Test + public void testDoubleSerde() { + EValue evalue = EValue.from(1.345e-2d); + byte[] bytes = evalue.toByteArray(); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isDouble(), true); + assertEquals(1.345e-2d, deser.toDouble(), 1e-6); + } + + @Test + public void testLongTensorSerde() { + long data[] = {1, 2, 3, 4}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + + EValue evalue = EValue.from(tensor); + byte[] bytes = evalue.toByteArray(); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isTensor(), true); + Tensor deserTensor = deser.toTensor(); + long[] deserShape = deserTensor.shape(); + long[] deserData = deserTensor.getDataAsLongArray(); + + for (int i = 0; i < data.length; i++) { + assertEquals(data[i], deserData[i]); + } + + for (int i = 0; i < shape.length; i++) { + assertEquals(shape[i], deserShape[i]); + } + } + + @Test + public void testFloatTensorSerde() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + + EValue evalue = EValue.from(tensor); + byte[] bytes = evalue.toByteArray(); + + EValue deser = EValue.fromByteArray(bytes); + assertEquals(deser.isTensor(), true); + Tensor deserTensor = deser.toTensor(); + long[] deserShape = deserTensor.shape(); + float[] deserData = deserTensor.getDataAsFloatArray(); + + for (int i = 0; i < data.length; i++) { + assertEquals(data[i], deserData[i], 1e-5); + } + + for (int i = 0; i < shape.length; i++) { + assertEquals(shape[i], deserShape[i]); + } } } diff --git a/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java index 7933113412c..9811a1d0ff6 100644 --- a/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java +++ b/extension/android_test/src/test/java/org/pytorch/executorch/TensorTest.java @@ -9,9 +9,6 @@ package org.pytorch.executorch; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.fail; import java.nio.ByteBuffer; @@ -19,11 +16,9 @@ import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; - import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.pytorch.executorch.Tensor; /** Unit tests for {@link Tensor}. */ @RunWith(JUnit4.class) @@ -243,28 +238,68 @@ public void testIllegalArguments() { long mismatchShape[] = {1, 2}; try { - Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape); + Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape); fail("Should have thrown an exception"); } catch (IllegalArgumentException e) { // expected } try { - Tensor tensor = Tensor.fromBlob(data, null); + Tensor tensor = Tensor.fromBlob(data, null); fail("Should have thrown an exception"); } catch (IllegalArgumentException e) { // expected } try { - Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues); + Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues); fail("Should have thrown an exception"); } catch (IllegalArgumentException e) { // expected } try { - Tensor tensor = Tensor.fromBlob(data, mismatchShape); + Tensor tensor = Tensor.fromBlob(data, mismatchShape); fail("Should have thrown an exception"); } catch (IllegalArgumentException e) { // expected } } + + @Test + public void testLongTensorSerde() { + long data[] = {1, 2, 3, 4}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + byte[] bytes = tensor.toByteArray(); + + Tensor deser = Tensor.fromByteArray(bytes); + long[] deserShape = deser.shape(); + long[] deserData = deser.getDataAsLongArray(); + + for (int i = 0; i < data.length; i++) { + assertEquals(data[i], deserData[i]); + } + + for (int i = 0; i < shape.length; i++) { + assertEquals(shape[i], deserShape[i]); + } + } + + @Test + public void testFloatTensorSerde() { + float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; + long shape[] = {2, 2}; + Tensor tensor = Tensor.fromBlob(data, shape); + byte[] bytes = tensor.toByteArray(); + + Tensor deser = Tensor.fromByteArray(bytes); + long[] deserShape = deser.shape(); + float[] deserData = deser.getDataAsFloatArray(); + + for (int i = 0; i < data.length; i++) { + assertEquals(data[i], deserData[i], 1e-5); + } + + for (int i = 0; i < shape.length; i++) { + assertEquals(shape[i], deserShape[i]); + } + } }