@@ -53,9 +53,10 @@ public abstract class Tensor {
5353
5454 @ DoNotStrip final long [] shape ;
5555
56+ private static final int BYTE_SIZE_BYTES = 1 ;
5657 private static final int INT_SIZE_BYTES = 4 ;
57- private static final int FLOAT_SIZE_BYTES = 4 ;
5858 private static final int LONG_SIZE_BYTES = 8 ;
59+ private static final int FLOAT_SIZE_BYTES = 4 ;
5960 private static final int DOUBLE_SIZE_BYTES = 8 ;
6061
6162 /**
@@ -690,38 +691,38 @@ private static Tensor nativeNewTensor(
690691 public byte [] toByteArray () {
691692 int dtypeSize = 0 ;
692693 byte [] tensorAsByteArray = null ;
693- if (dtype () == DType .FLOAT ) {
694- dtypeSize = 4 ;
695- tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
696- Tensor_float32 thiz = (Tensor_float32 ) this ;
697- ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
698- } else if (dtype () == DType .DOUBLE ) {
699- dtypeSize = 8 ;
700- tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
701- Tensor_float64 thiz = (Tensor_float64 ) this ;
702- ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
703- } else if (dtype () == DType .UINT8 ) {
704- dtypeSize = 1 ;
694+ if (dtype () == DType .UINT8 ) {
695+ dtypeSize = BYTE_SIZE_BYTES ;
705696 tensorAsByteArray = new byte [(int ) numel ()];
706697 Tensor_uint8 thiz = (Tensor_uint8 ) this ;
707698 ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsUnsignedByteArray ());
708699 } else if (dtype () == DType .INT8 ) {
709- dtypeSize = 1 ;
700+ dtypeSize = BYTE_SIZE_BYTES ;
710701 tensorAsByteArray = new byte [(int ) numel ()];
711702 Tensor_int8 thiz = (Tensor_int8 ) this ;
712703 ByteBuffer .wrap (tensorAsByteArray ).put (thiz .getDataAsByteArray ());
713704 } else if (dtype () == DType .INT16 ) {
714705 throw new IllegalArgumentException ("DType.INT16 is not supported in Java so far" );
715706 } else if (dtype () == DType .INT32 ) {
716- dtypeSize = 4 ;
707+ dtypeSize = INT_SIZE_BYTES ;
717708 tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
718709 Tensor_int32 thiz = (Tensor_int32 ) this ;
719710 ByteBuffer .wrap (tensorAsByteArray ).asIntBuffer ().put (thiz .getDataAsIntArray ());
720711 } else if (dtype () == DType .INT64 ) {
721- dtypeSize = 8 ;
712+ dtypeSize = LONG_SIZE_BYTES ;
722713 tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
723714 Tensor_int64 thiz = (Tensor_int64 ) this ;
724715 ByteBuffer .wrap (tensorAsByteArray ).asLongBuffer ().put (thiz .getDataAsLongArray ());
716+ } else if (dtype () == DType .FLOAT ) {
717+ dtypeSize = FLOAT_SIZE_BYTES ;
718+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
719+ Tensor_float32 thiz = (Tensor_float32 ) this ;
720+ ByteBuffer .wrap (tensorAsByteArray ).asFloatBuffer ().put (thiz .getDataAsFloatArray ());
721+ } else if (dtype () == DType .DOUBLE ) {
722+ dtypeSize = DOUBLE_SIZE_BYTES ;
723+ tensorAsByteArray = new byte [(int ) numel () * dtypeSize ];
724+ Tensor_float64 thiz = (Tensor_float64 ) this ;
725+ ByteBuffer .wrap (tensorAsByteArray ).asDoubleBuffer ().put (thiz .getDataAsDoubleArray ());
725726 } else {
726727 throw new IllegalArgumentException ("Unknown Tensor dtype" );
727728 }
@@ -752,30 +753,30 @@ public static Tensor fromByteArray(byte[] bytes) {
752753 if (!buffer .hasRemaining ()) {
753754 throw new IllegalArgumentException ("invalid buffer" );
754755 }
755- byte scalarType = buffer .get ();
756- byte numberOfDimensions = buffer .get ();
757- long [] shape = new long [(int ) numberOfDimensions ];
756+ byte dtype = buffer .get ();
757+ byte shapeLength = buffer .get ();
758+ long [] shape = new long [(int ) shapeLength ];
758759 long numel = 1 ;
759- for (int i = 0 ; i < numberOfDimensions ; i ++) {
760+ for (int i = 0 ; i < shapeLength ; i ++) {
760761 int dim = buffer .getInt ();
761762 if (dim < 0 ) {
762763 throw new IllegalArgumentException ("invalid shape" );
763764 }
764765 shape [i ] = dim ;
765766 numel *= dim ;
766767 }
767- if (scalarType == DType .FLOAT .jniCode ) {
768- return new Tensor_float32 (buffer .asFloatBuffer (), shape );
769- } else if (scalarType == DType .DOUBLE .jniCode ) {
770- return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
771- } else if (scalarType == DType .UINT8 .jniCode ) {
768+ if (dtype == DType .UINT8 .jniCode ) {
772769 return new Tensor_uint8 (buffer , shape );
773- } else if (scalarType == DType .INT8 .jniCode ) {
770+ } else if (dtype == DType .INT8 .jniCode ) {
774771 return new Tensor_int8 (buffer , shape );
775- } else if (scalarType == DType .INT16 .jniCode ) {
772+ } else if (dtype == DType .INT32 .jniCode ) {
776773 return new Tensor_int32 (buffer .asIntBuffer (), shape );
777- } else if (scalarType == DType .INT64 .jniCode ) {
774+ } else if (dtype == DType .INT64 .jniCode ) {
778775 return new Tensor_int64 (buffer .asLongBuffer (), shape );
776+ } else if (dtype == DType .FLOAT .jniCode ) {
777+ return new Tensor_float32 (buffer .asFloatBuffer (), shape );
778+ } else if (dtype == DType .DOUBLE .jniCode ) {
779+ return new Tensor_float64 (buffer .asDoubleBuffer (), shape );
779780 } else {
780781 throw new IllegalArgumentException ("Unknown Tensor dtype" );
781782 }
0 commit comments