From e8216f8a4ab0fc011dd292a8df4c412b3eae2141 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 24 Apr 2025 17:10:24 -0700 Subject: [PATCH] Java test use kotlin --- .../android/executorch_android/build.gradle | 5 + .../org/pytorch/executorch/EValueTest.java | 230 ------------- .../java/org/pytorch/executorch/EValueTest.kt | 224 +++++++++++++ .../org/pytorch/executorch/TensorTest.java | 305 ------------------ .../java/org/pytorch/executorch/TensorTest.kt | 296 +++++++++++++++++ extension/android/gradle/libs.versions.toml | 5 + 6 files changed, 530 insertions(+), 535 deletions(-) delete mode 100644 extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.java create mode 100644 extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt delete mode 100644 extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.java create mode 100644 extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index 15088f4097f..6fd07027dda 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -9,6 +9,7 @@ plugins { id "com.android.library" version "8.9.0" id "com.vanniktech.maven.publish" version "0.31.0" + alias(libs.plugins.jetbrains.kotlin.android) } android { @@ -34,6 +35,9 @@ android { resources.srcDirs += [ 'src/androidTest/resources' ] } } + kotlinOptions { + jvmTarget = '1.8' + } } task copyTestRes(type: Exec) { @@ -43,6 +47,7 @@ task copyTestRes(type: Exec) { dependencies { implementation 'com.facebook.fbjni:fbjni:0.5.1' implementation 'com.facebook.soloader:nativeloader:0.10.5' + implementation libs.core.ktx testImplementation 'junit:junit:4.12' androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test:rules:1.2.0' diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.java deleted file mode 100644 index cbeb3a7b634..00000000000 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.java +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -import java.util.Arrays; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Unit tests for {@link EValue}. */ -@RunWith(JUnit4.class) -public class EValueTest { - @Test - public void testNone() { - EValue evalue = EValue.optionalNone(); - assertTrue(evalue.isNone()); - } - - @Test - public void testTensorValue() { - long[] data = {1, 2, 3}; - long[] shape = {1, 3}; - EValue evalue = EValue.from(Tensor.fromBlob(data, shape)); - assertTrue(evalue.isTensor()); - assertTrue(Arrays.equals(evalue.toTensor().shape, shape)); - assertTrue(Arrays.equals(evalue.toTensor().getDataAsLongArray(), data)); - } - - @Test - public void testBoolValue() { - EValue evalue = EValue.from(true); - assertTrue(evalue.isBool()); - assertTrue(evalue.toBool()); - } - - @Test - public void testIntValue() { - EValue evalue = EValue.from(1); - assertTrue(evalue.isInt()); - assertEquals(evalue.toInt(), 1); - } - - @Test - public void testDoubleValue() { - EValue evalue = EValue.from(0.1d); - assertTrue(evalue.isDouble()); - assertEquals(evalue.toDouble(), 0.1d, 0.0001d); - } - - @Test - public void testStringValue() { - EValue evalue = EValue.from("a"); - assertTrue(evalue.isString()); - assertEquals(evalue.toStr(), "a"); - } - - @Test - public void testAllIllegalCast() { - EValue evalue = EValue.optionalNone(); - assertTrue(evalue.isNone()); - - // try Tensor - assertFalse(evalue.isTensor()); - try { - evalue.toTensor(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try bool - assertFalse(evalue.isBool()); - try { - evalue.toBool(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try int - assertFalse(evalue.isInt()); - try { - evalue.toInt(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try double - assertFalse(evalue.isDouble()); - try { - evalue.toDouble(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - - // try string - assertFalse(evalue.isString()); - try { - evalue.toStr(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - } - } - - @Test - public void testNoneSerde() { - EValue evalue = EValue.optionalNone(); - byte[] bytes = evalue.toByteArray(); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isNone(), true); - } - - @Test - public void testBoolSerde() { - EValue evalue = EValue.from(true); - byte[] bytes = evalue.toByteArray(); - assertEquals(1, bytes[1]); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isBool(), true); - assertEquals(deser.toBool(), true); - } - - @Test - public void testBoolSerde2() { - EValue evalue = EValue.from(false); - byte[] bytes = evalue.toByteArray(); - assertEquals(0, bytes[1]); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isBool(), true); - assertEquals(deser.toBool(), false); - } - - @Test - public void testIntSerde() { - EValue evalue = EValue.from(1); - byte[] bytes = evalue.toByteArray(); - assertEquals(0, bytes[1]); - assertEquals(0, bytes[2]); - assertEquals(0, bytes[3]); - assertEquals(0, bytes[4]); - assertEquals(0, bytes[5]); - assertEquals(0, bytes[6]); - assertEquals(0, bytes[7]); - assertEquals(1, bytes[8]); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isInt(), true); - assertEquals(deser.toInt(), 1); - } - - @Test - public void testLargeIntSerde() { - EValue evalue = EValue.from(256000); - byte[] bytes = evalue.toByteArray(); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isInt(), true); - assertEquals(deser.toInt(), 256000); - } - - @Test - public void testDoubleSerde() { - EValue evalue = EValue.from(1.345e-2d); - byte[] bytes = evalue.toByteArray(); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isDouble(), true); - assertEquals(1.345e-2d, deser.toDouble(), 1e-6); - } - - @Test - public void testLongTensorSerde() { - long data[] = {1, 2, 3, 4}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - - EValue evalue = EValue.from(tensor); - byte[] bytes = evalue.toByteArray(); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isTensor(), true); - Tensor deserTensor = deser.toTensor(); - long[] deserShape = deserTensor.shape(); - long[] deserData = deserTensor.getDataAsLongArray(); - - for (int i = 0; i < data.length; i++) { - assertEquals(data[i], deserData[i]); - } - - for (int i = 0; i < shape.length; i++) { - assertEquals(shape[i], deserShape[i]); - } - } - - @Test - public void testFloatTensorSerde() { - float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - - EValue evalue = EValue.from(tensor); - byte[] bytes = evalue.toByteArray(); - - EValue deser = EValue.fromByteArray(bytes); - assertEquals(deser.isTensor(), true); - Tensor deserTensor = deser.toTensor(); - long[] deserShape = deserTensor.shape(); - float[] deserData = deserTensor.getDataAsFloatArray(); - - for (int i = 0; i < data.length; i++) { - assertEquals(data[i], deserData[i], 1e-5); - } - - for (int i = 0; i < shape.length; i++) { - assertEquals(shape[i], deserShape[i]); - } - } -} diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt new file mode 100644 index 00000000000..0e56480d621 --- /dev/null +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -0,0 +1,224 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +package org.pytorch.executorch + +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +/** Unit tests for [EValue]. */ +@RunWith(JUnit4::class) +class EValueTest { + @Test + fun testNone() { + val evalue = EValue.optionalNone() + Assert.assertTrue(evalue.isNone) + } + + @Test + fun testTensorValue() { + val data = longArrayOf(1, 2, 3) + val shape = longArrayOf(1, 3) + val evalue = EValue.from(Tensor.fromBlob(data, shape)) + Assert.assertTrue(evalue.isTensor) + Assert.assertTrue(evalue.toTensor().shape.contentEquals(shape)) + Assert.assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + } + + @Test + fun testBoolValue() { + val evalue = EValue.from(true) + Assert.assertTrue(evalue.isBool) + Assert.assertTrue(evalue.toBool()) + } + + @Test + fun testIntValue() { + val evalue = EValue.from(1) + Assert.assertTrue(evalue.isInt) + Assert.assertEquals(evalue.toInt(), 1) + } + + @Test + fun testDoubleValue() { + val evalue = EValue.from(0.1) + Assert.assertTrue(evalue.isDouble) + Assert.assertEquals(evalue.toDouble(), 0.1, 0.0001) + } + + @Test + fun testStringValue() { + val evalue = EValue.from("a") + Assert.assertTrue(evalue.isString) + Assert.assertEquals(evalue.toStr(), "a") + } + + @Test + fun testAllIllegalCast() { + val evalue = EValue.optionalNone() + Assert.assertTrue(evalue.isNone) + + // try Tensor + Assert.assertFalse(evalue.isTensor) + try { + evalue.toTensor() + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + } + + // try bool + Assert.assertFalse(evalue.isBool) + try { + evalue.toBool() + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + } + + // try int + Assert.assertFalse(evalue.isInt) + try { + evalue.toInt() + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + } + + // try double + Assert.assertFalse(evalue.isDouble) + try { + evalue.toDouble() + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + } + + // try string + Assert.assertFalse(evalue.isString) + try { + evalue.toStr() + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + } + } + + @Test + fun testNoneSerde() { + val evalue = EValue.optionalNone() + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isNone, true) + } + + @Test + fun testBoolSerde() { + val evalue = EValue.from(true) + val bytes = evalue.toByteArray() + Assert.assertEquals(1, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isBool, true) + Assert.assertEquals(deser.toBool(), true) + } + + @Test + fun testBoolSerde2() { + val evalue = EValue.from(false) + val bytes = evalue.toByteArray() + Assert.assertEquals(0, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isBool, true) + Assert.assertEquals(deser.toBool(), false) + } + + @Test + fun testIntSerde() { + val evalue = EValue.from(1) + val bytes = evalue.toByteArray() + Assert.assertEquals(0, bytes[1].toLong()) + Assert.assertEquals(0, bytes[2].toLong()) + Assert.assertEquals(0, bytes[3].toLong()) + Assert.assertEquals(0, bytes[4].toLong()) + Assert.assertEquals(0, bytes[5].toLong()) + Assert.assertEquals(0, bytes[6].toLong()) + Assert.assertEquals(0, bytes[7].toLong()) + Assert.assertEquals(1, bytes[8].toLong()) + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isInt, true) + Assert.assertEquals(deser.toInt(), 1) + } + + @Test + fun testLargeIntSerde() { + val evalue = EValue.from(256000) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isInt, true) + Assert.assertEquals(deser.toInt(), 256000) + } + + @Test + fun testDoubleSerde() { + val evalue = EValue.from(1.345e-2) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isDouble, true) + Assert.assertEquals(1.345e-2, deser.toDouble(), 1e-6) + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsLongArray + + for (i in data.indices) { + Assert.assertEquals(data[i], deserData[i]) + } + + for (i in shape.indices) { + Assert.assertEquals(shape[i], deserShape[i]) + } + } + + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + Assert.assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsFloatArray + + for (i in data.indices) { + Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) + } + + for (i in shape.indices) { + Assert.assertEquals(shape[i], deserShape[i]) + } + } +} diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.java b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.java deleted file mode 100644 index 9811a1d0ff6..00000000000 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.java +++ /dev/null @@ -1,305 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -package org.pytorch.executorch; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Unit tests for {@link Tensor}. */ -@RunWith(JUnit4.class) -public class TensorTest { - - @Test - public void testFloatTensor() { - float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.FLOAT); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); - assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); - assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); - assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); - - FloatBuffer floatBuffer = Tensor.allocateFloatBuffer(4); - floatBuffer.put(data); - tensor = Tensor.fromBlob(floatBuffer, shape); - assertEquals(tensor.dtype(), DType.FLOAT); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsFloatArray()[0], 1e-5); - assertEquals(data[1], tensor.getDataAsFloatArray()[1], 1e-5); - assertEquals(data[2], tensor.getDataAsFloatArray()[2], 1e-5); - assertEquals(data[3], tensor.getDataAsFloatArray()[3], 1e-5); - } - - @Test - public void testIntTensor() { - int data[] = {Integer.MIN_VALUE, 0, 1, Integer.MAX_VALUE}; - long shape[] = {1, 4, 1}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.INT32); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsIntArray()[0]); - assertEquals(data[1], tensor.getDataAsIntArray()[1]); - assertEquals(data[2], tensor.getDataAsIntArray()[2]); - assertEquals(data[3], tensor.getDataAsIntArray()[3]); - - IntBuffer intBuffer = Tensor.allocateIntBuffer(4); - intBuffer.put(data); - tensor = Tensor.fromBlob(intBuffer, shape); - assertEquals(tensor.dtype(), DType.INT32); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsIntArray()[0]); - assertEquals(data[1], tensor.getDataAsIntArray()[1]); - assertEquals(data[2], tensor.getDataAsIntArray()[2]); - assertEquals(data[3], tensor.getDataAsIntArray()[3]); - } - - @Test - public void testDoubleTensor() { - double data[] = {Double.MIN_VALUE, 0.0d, 0.1d, Double.MAX_VALUE}; - long shape[] = {1, 4}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.DOUBLE); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); - assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); - assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); - assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5); - - DoubleBuffer doubleBuffer = Tensor.allocateDoubleBuffer(4); - doubleBuffer.put(data); - tensor = Tensor.fromBlob(doubleBuffer, shape); - assertEquals(tensor.dtype(), DType.DOUBLE); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsDoubleArray()[0], 1e-5); - assertEquals(data[1], tensor.getDataAsDoubleArray()[1], 1e-5); - assertEquals(data[2], tensor.getDataAsDoubleArray()[2], 1e-5); - assertEquals(data[3], tensor.getDataAsDoubleArray()[3], 1e-5); - } - - @Test - public void testLongTensor() { - long data[] = {Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE}; - long shape[] = {4, 1}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.INT64); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsLongArray()[0]); - assertEquals(data[1], tensor.getDataAsLongArray()[1]); - assertEquals(data[2], tensor.getDataAsLongArray()[2]); - assertEquals(data[3], tensor.getDataAsLongArray()[3]); - - LongBuffer longBuffer = Tensor.allocateLongBuffer(4); - longBuffer.put(data); - tensor = Tensor.fromBlob(longBuffer, shape); - assertEquals(tensor.dtype(), DType.INT64); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsLongArray()[0]); - assertEquals(data[1], tensor.getDataAsLongArray()[1]); - assertEquals(data[2], tensor.getDataAsLongArray()[2]); - assertEquals(data[3], tensor.getDataAsLongArray()[3]); - } - - @Test - public void testSignedByteTensor() { - byte data[] = {Byte.MIN_VALUE, (byte) 0, (byte) 1, Byte.MAX_VALUE}; - long shape[] = {1, 1, 4}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.INT8); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsByteArray()[0]); - assertEquals(data[1], tensor.getDataAsByteArray()[1]); - assertEquals(data[2], tensor.getDataAsByteArray()[2]); - assertEquals(data[3], tensor.getDataAsByteArray()[3]); - - ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); - byteBuffer.put(data); - tensor = Tensor.fromBlob(byteBuffer, shape); - assertEquals(tensor.dtype(), DType.INT8); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsByteArray()[0]); - assertEquals(data[1], tensor.getDataAsByteArray()[1]); - assertEquals(data[2], tensor.getDataAsByteArray()[2]); - assertEquals(data[3], tensor.getDataAsByteArray()[3]); - } - - @Test - public void testUnsignedByteTensor() { - byte data[] = {(byte) 0, (byte) 1, (byte) 2, (byte) 255}; - long shape[] = {4, 1, 1}; - Tensor tensor = Tensor.fromBlobUnsigned(data, shape); - assertEquals(tensor.dtype(), DType.UINT8); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); - assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); - assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]); - assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); - - ByteBuffer byteBuffer = Tensor.allocateByteBuffer(4); - byteBuffer.put(data); - tensor = Tensor.fromBlobUnsigned(byteBuffer, shape); - assertEquals(tensor.dtype(), DType.UINT8); - assertEquals(shape[0], tensor.shape()[0]); - assertEquals(shape[1], tensor.shape()[1]); - assertEquals(shape[2], tensor.shape()[2]); - assertEquals(4, tensor.numel()); - assertEquals(data[0], tensor.getDataAsUnsignedByteArray()[0]); - assertEquals(data[1], tensor.getDataAsUnsignedByteArray()[1]); - assertEquals(data[2], tensor.getDataAsUnsignedByteArray()[2]); - assertEquals(data[3], tensor.getDataAsUnsignedByteArray()[3]); - } - - @Test - public void testIllegalDataTypeException() { - float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - assertEquals(tensor.dtype(), DType.FLOAT); - - try { - tensor.getDataAsByteArray(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - // expected - } - try { - tensor.getDataAsUnsignedByteArray(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - // expected - } - try { - tensor.getDataAsIntArray(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - // expected - } - try { - tensor.getDataAsDoubleArray(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - // expected - } - try { - tensor.getDataAsLongArray(); - fail("Should have thrown an exception"); - } catch (IllegalStateException e) { - // expected - } - } - - @Test - public void testIllegalArguments() { - float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; - long shapeWithNegativeValues[] = {-1, 2}; - long mismatchShape[] = {1, 2}; - - try { - Tensor tensor = Tensor.fromBlob((float[]) null, mismatchShape); - fail("Should have thrown an exception"); - } catch (IllegalArgumentException e) { - // expected - } - try { - Tensor tensor = Tensor.fromBlob(data, null); - fail("Should have thrown an exception"); - } catch (IllegalArgumentException e) { - // expected - } - try { - Tensor tensor = Tensor.fromBlob(data, shapeWithNegativeValues); - fail("Should have thrown an exception"); - } catch (IllegalArgumentException e) { - // expected - } - try { - Tensor tensor = Tensor.fromBlob(data, mismatchShape); - fail("Should have thrown an exception"); - } catch (IllegalArgumentException e) { - // expected - } - } - - @Test - public void testLongTensorSerde() { - long data[] = {1, 2, 3, 4}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - byte[] bytes = tensor.toByteArray(); - - Tensor deser = Tensor.fromByteArray(bytes); - long[] deserShape = deser.shape(); - long[] deserData = deser.getDataAsLongArray(); - - for (int i = 0; i < data.length; i++) { - assertEquals(data[i], deserData[i]); - } - - for (int i = 0; i < shape.length; i++) { - assertEquals(shape[i], deserShape[i]); - } - } - - @Test - public void testFloatTensorSerde() { - float data[] = {Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE}; - long shape[] = {2, 2}; - Tensor tensor = Tensor.fromBlob(data, shape); - byte[] bytes = tensor.toByteArray(); - - Tensor deser = Tensor.fromByteArray(bytes); - long[] deserShape = deser.shape(); - float[] deserData = deser.getDataAsFloatArray(); - - for (int i = 0; i < data.length; i++) { - assertEquals(data[i], deserData[i], 1e-5); - } - - for (int i = 0; i < shape.length; i++) { - assertEquals(shape[i], deserShape[i]); - } - } -} diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt new file mode 100644 index 00000000000..4b206c8efbd --- /dev/null +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -0,0 +1,296 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +package org.pytorch.executorch + +import org.junit.Assert +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +/** Unit tests for [Tensor]. */ +@RunWith(JUnit4::class) +class TensorTest { + @Test + fun testFloatTensor() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + var tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.FLOAT) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + + val floatBuffer = Tensor.allocateFloatBuffer(4) + floatBuffer.put(data) + tensor = Tensor.fromBlob(floatBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.FLOAT) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + } + + @Test + fun testIntTensor() { + val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) + val shape = longArrayOf(1, 4, 1) + var tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.INT32) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + + val intBuffer = Tensor.allocateIntBuffer(4) + intBuffer.put(data) + tensor = Tensor.fromBlob(intBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.INT32) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + } + + @Test + fun testDoubleTensor() { + val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) + val shape = longArrayOf(1, 4) + var tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.DOUBLE) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + + val doubleBuffer = Tensor.allocateDoubleBuffer(4) + doubleBuffer.put(data) + tensor = Tensor.fromBlob(doubleBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.DOUBLE) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + } + + @Test + fun testLongTensor() { + val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) + val shape = longArrayOf(4, 1) + var tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.INT64) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) + Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) + Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) + Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) + + val longBuffer = Tensor.allocateLongBuffer(4) + longBuffer.put(data) + tensor = Tensor.fromBlob(longBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.INT64) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) + Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) + Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) + Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) + } + + @Test + fun testSignedByteTensor() { + val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) + val shape = longArrayOf(1, 1, 4) + var tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.INT8) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlob(byteBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.INT8) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + } + + @Test + fun testUnsignedByteTensor() { + val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) + val shape = longArrayOf(4, 1, 1) + var tensor = Tensor.fromBlobUnsigned(data, shape) + Assert.assertEquals(tensor.dtype(), DType.UINT8) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) + Assert.assertEquals(tensor.dtype(), DType.UINT8) + Assert.assertEquals(shape[0], tensor.shape()[0]) + Assert.assertEquals(shape[1], tensor.shape()[1]) + Assert.assertEquals(shape[2], tensor.shape()[2]) + Assert.assertEquals(4, tensor.numel()) + Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + } + + @Test + fun testIllegalDataTypeException() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + Assert.assertEquals(tensor.dtype(), DType.FLOAT) + + try { + tensor.dataAsByteArray + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + // expected + } + try { + tensor.dataAsUnsignedByteArray + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + // expected + } + try { + tensor.dataAsIntArray + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + // expected + } + try { + tensor.dataAsDoubleArray + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + // expected + } + try { + tensor.dataAsLongArray + Assert.fail("Should have thrown an exception") + } catch (e: IllegalStateException) { + // expected + } + } + + @Test + fun testIllegalArguments() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shapeWithNegativeValues = longArrayOf(-1, 2) + val mismatchShape = longArrayOf(1, 2) + + try { + val tensor = Tensor.fromBlob(null as FloatArray?, mismatchShape) + Assert.fail("Should have thrown an exception") + } catch (e: IllegalArgumentException) { + // expected + } + try { + val tensor = Tensor.fromBlob(data, null) + Assert.fail("Should have thrown an exception") + } catch (e: IllegalArgumentException) { + // expected + } + try { + val tensor = Tensor.fromBlob(data, shapeWithNegativeValues) + Assert.fail("Should have thrown an exception") + } catch (e: IllegalArgumentException) { + // expected + } + try { + val tensor = Tensor.fromBlob(data, mismatchShape) + Assert.fail("Should have thrown an exception") + } catch (e: IllegalArgumentException) { + // expected + } + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() + + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsLongArray + + for (i in data.indices) { + Assert.assertEquals(data[i], deserData[i]) + } + + for (i in shape.indices) { + Assert.assertEquals(shape[i], deserShape[i]) + } + } + + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() + + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsFloatArray + + for (i in data.indices) { + Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) + } + + for (i in shape.indices) { + Assert.assertEquals(shape[i], deserShape[i]) + } + } +} diff --git a/extension/android/gradle/libs.versions.toml b/extension/android/gradle/libs.versions.toml index 561988cb1f6..dd2cf3f039a 100644 --- a/extension/android/gradle/libs.versions.toml +++ b/extension/android/gradle/libs.versions.toml @@ -5,8 +5,13 @@ commons-math3 = "3.6.1" guava = "32.1.3-jre" junit = "4.13.2" +core-ktx = "1.13.1" +kotlin = "2.1.20" [libraries] commons-math3 = { module = "org.apache.commons:commons-math3", version.ref = "commons-math3" } guava = { module = "com.google.guava:guava", version.ref = "guava" } junit = { module = "junit:junit", version.ref = "junit" } +core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" } +[plugins] +jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }