@@ -53,9 +53,10 @@ public abstract class Tensor {
5353
5454 @ DoNotStrip final long [] shape ;
5555
56+ private static final int BYTE_SIZE_BYTES = 1 ;
5657 private static final int INT_SIZE_BYTES = 4 ;
57- private static final int FLOAT_SIZE_BYTES = 4 ;
5858 private static final int LONG_SIZE_BYTES = 8 ;
59+ private static final int FLOAT_SIZE_BYTES = 4 ;
5960 private static final int DOUBLE_SIZE_BYTES = 8 ;
6061
6162 /**
@@ -679,4 +680,105 @@ private static Tensor nativeNewTensor(
679680 tensor .mHybridData = hybridData ;
680681 return tensor ;
681682 }
683+
684+ /**
685+ * Serializes a {@code Tensor} into a byte array.
686+ *
687+ * @return The serialized byte array.
688+ * @apiNote This method is experimental and subject to change without notice. This does NOT
689+ * supoprt list type.
690+ */
691+ public byte [] toByteArray () {
692+ int dtypeSize = 0 ;
693+ byte [] tensorAsByteArray = null ;
694+ if (dtype () == DType .UINT8 ) {
695+ dtypeSize = BYTE_SIZE_BYTES ;
696+ tensorAsByteArray = new byte [(int ) numel ()];
697+ Tensor_uint8 thiz = (Tensor_uint8 ) this ;
698+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
699+ } else if (dtype () == DType .INT8 ) {
700+ dtypeSize = BYTE_SIZE_BYTES ;
701+ tensorAsByteArray = new byte [(int ) numel ()];
702+ Tensor_int8 thiz = (Tensor_int8 ) this ;
703+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
704+ } else if (dtype () == DType .INT16 ) {
705+ throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
706+ } else if (dtype () == DType .INT32 ) {
707+ dtypeSize = INT_SIZE_BYTES ;
708+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
709+ Tensor_int32 thiz = (Tensor_int32 ) this ;
710+ ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
711+ } else if (dtype () == DType .INT64 ) {
712+ dtypeSize = LONG_SIZE_BYTES ;
713+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
714+ Tensor_int64 thiz = (Tensor_int64 ) this ;
715+ ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
716+ } else if (dtype () == DType .FLOAT ) {
717+ dtypeSize = FLOAT_SIZE_BYTES ;
718+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
719+ Tensor_float32 thiz = (Tensor_float32 ) this ;
720+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
721+ } else if (dtype () == DType .DOUBLE ) {
722+ dtypeSize = DOUBLE_SIZE_BYTES ;
723+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
724+ Tensor_float64 thiz = (Tensor_float64 ) this ;
725+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
726+ } else {
727+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
728+ }
729+ ByteBuffer byteBuffer =
730+ ByteBuffer .allocate (1 + 1 + 4 * shape .length + dtypeSize * (int ) numel ());
731+ byteBuffer .put ((byte ) dtype ().jniCode );
732+ byteBuffer .put ((byte ) shape .length );
733+ for (long s : shape ) {
734+ byteBuffer .putInt ((int ) s );
735+ }
736+ byteBuffer .put (tensorAsByteArray );
737+ return byteBuffer .array ();
738+ }
739+
740+ /**
741+ * Deserializes a {@code Tensor} from a byte[].
742+ *
743+ * @param buffer The byte array to deserialize from.
744+ * @return The deserialized {@code Tensor}.
745+ * @apiNote This method is experimental and subject to change without notice. This does NOT
746+ * supoprt list type.
747+ */
748+ public static Tensor fromByteArray (byte [] bytes ) {
749+ if (bytes == null ) {
750+ throw new IllegalArgumentException ("bytes cannot be null" );
751+ }
752+ ByteBuffer buffer = ByteBuffer .wrap (bytes );
753+ if (!buffer .hasRemaining ()) {
754+ throw new IllegalArgumentException ("invalid buffer" );
755+ }
756+ byte dtype = buffer .get ();
757+ byte shapeLength = buffer .get ();
758+ long [] shape = new long [(int ) shapeLength ];
759+ long numel = 1 ;
760+ for (int i = 0 ; i < shapeLength ; i ++) {
761+ int dim = buffer .getInt ();
762+ if (dim < 0 ) {
763+ throw new IllegalArgumentException ("invalid shape" );
764+ }
765+ shape [i ] = dim ;
766+ numel *= dim ;
767+ }
768+ if (dtype == DType .UINT8 .jniCode ) {
769+ return new Tensor_uint8 (buffer , shape );
770+ } else if (dtype == DType .INT8 .jniCode ) {
771+ return new Tensor_int8 (buffer , shape );
772+ } else if (dtype == DType .INT32 .jniCode ) {
773+ return new Tensor_int32 (buffer .asIntBuffer (), shape );
774+ } else if (dtype == DType .INT64 .jniCode ) {
775+ return new Tensor_int64 (buffer .asLongBuffer (), shape );
776+ } else if (dtype == DType .FLOAT .jniCode ) {
777+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
778+ } else if (dtype == DType .DOUBLE .jniCode ) {
779+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
780+ } else {
781+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
782+ }
783+ }
682784}
0 commit comments