@@ -679,4 +679,105 @@ private static Tensor nativeNewTensor(
679679 tensor .mHybridData = hybridData ;
680680 return tensor ;
681681 }
682+
683+ /**
684+ * Serializes a {@code Tensor} into a byte array.
685+ *
686+ * @return The serialized byte array.
687+ * @apiNote This method is experimental and subject to change without notice. This does NOT
688+ * supoprt list type.
689+ */
690+ public byte [] toByteArray () {
691+ int dtypeSize = 0 ;
692+ byte [] tensorAsByteArray = null ;
693+ if (dtype () == DType .FLOAT ) {
694+ dtypeSize = 4 ;
695+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
696+ Tensor_float32 thiz = (Tensor_float32 ) this ;
697+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
698+ } else if (dtype () == DType .DOUBLE ) {
699+ dtypeSize = 8 ;
700+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
701+ Tensor_float64 thiz = (Tensor_float64 ) this ;
702+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
703+ } else if (dtype () == DType .UINT8 ) {
704+ dtypeSize = 1 ;
705+ tensorAsByteArray = new byte [(int ) numel ()];
706+ Tensor_uint8 thiz = (Tensor_uint8 ) this ;
707+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
708+ } else if (dtype () == DType .INT8 ) {
709+ dtypeSize = 1 ;
710+ tensorAsByteArray = new byte [(int ) numel ()];
711+ Tensor_int8 thiz = (Tensor_int8 ) this ;
712+ ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
713+ } else if (dtype () == DType .INT16 ) {
714+ throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
715+ } else if (dtype () == DType .INT32 ) {
716+ dtypeSize = 4 ;
717+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
718+ Tensor_int32 thiz = (Tensor_int32 ) this ;
719+ ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
720+ } else if (dtype () == DType .INT64 ) {
721+ dtypeSize = 8 ;
722+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
723+ Tensor_int64 thiz = (Tensor_int64 ) this ;
724+ ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
725+ } else {
726+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
727+ }
728+ ByteBuffer byteBuffer =
729+ ByteBuffer .allocate (1 + 1 + 4 * shape .length + dtypeSize * (int ) numel ());
730+ byteBuffer .put ((byte ) dtype ().jniCode );
731+ byteBuffer .put ((byte ) shape .length );
732+ for (long s : shape ) {
733+ byteBuffer .putInt ((int ) s );
734+ }
735+ byteBuffer .put (tensorAsByteArray );
736+ return byteBuffer .array ();
737+ }
738+
739+ /**
740+ * Deserializes a {@code Tensor} from a byte[].
741+ *
742+ * @param buffer The byte array to deserialize from.
743+ * @return The deserialized {@code Tensor}.
744+ * @apiNote This method is experimental and subject to change without notice. This does NOT
745+ * supoprt list type.
746+ */
747+ public static Tensor fromByteArray (byte [] bytes ) {
748+ if (bytes == null ) {
749+ throw new IllegalArgumentException ("bytes cannot be null" );
750+ }
751+ ByteBuffer buffer = ByteBuffer .wrap (bytes );
752+ if (!buffer .hasRemaining ()) {
753+ throw new IllegalArgumentException ("invalid buffer" );
754+ }
755+ byte scalarType = buffer .get ();
756+ byte numberOfDimensions = buffer .get ();
757+ long [] shape = new long [(int ) numberOfDimensions ];
758+ long numel = 1 ;
759+ for (int i = 0 ; i < numberOfDimensions ; i ++) {
760+ int dim = buffer .getInt ();
761+ if (dim < 0 ) {
762+ throw new IllegalArgumentException ("invalid shape" );
763+ }
764+ shape [i ] = dim ;
765+ numel *= dim ;
766+ }
767+ if (scalarType == DType .FLOAT .jniCode ) {
768+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
769+ } else if (scalarType == DType .DOUBLE .jniCode ) {
770+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
771+ } else if (scalarType == DType .UINT8 .jniCode ) {
772+ return new Tensor_uint8 (buffer , shape );
773+ } else if (scalarType == DType .INT8 .jniCode ) {
774+ return new Tensor_int8 (buffer , shape );
775+ } else if (scalarType == DType .INT16 .jniCode ) {
776+ return new Tensor_int32 (buffer .asIntBuffer (), shape );
777+ } else if (scalarType == DType .INT64 .jniCode ) {
778+ return new Tensor_int64 (buffer .asLongBuffer (), shape );
779+ } else {
780+ throw new IllegalArgumentException ("Unknown Tensor dtype" );
781+ }
782+ }
682783}
0 commit comments