@@ -630,6 +630,44 @@ public String toString() {
630
630
}
631
631
}
632
632
633
+ /**
634
+ * A generic holder for 16b (fp16/bfloat16) tensor. Not intended for general use. User can only
635
+ * use #Tensor.getDataAsUnsignedByteArray() to get raw bytes. Users need to parse it.
636
+ */
637
+ static class Tensor_raw_data_16b extends Tensor {
638
+ private final ByteBuffer data ;
639
+ private final DType myDtype ;
640
+
641
+ private Tensor_raw_data_16b (ByteBuffer data , long [] shape , DType dtype ) {
642
+ super (shape );
643
+ this .data = data ;
644
+ this .myDtype = dtype ;
645
+ }
646
+
647
+ @ Override
648
+ public DType dtype () {
649
+ return myDtype ;
650
+ }
651
+
652
+ @ Override
653
+ Buffer getRawDataBuffer () {
654
+ return data ;
655
+ }
656
+
657
+ @ Override
658
+ public byte [] getDataAsUnsignedByteArray () {
659
+ data .rewind ();
660
+ byte [] arr = new byte [data .remaining ()];
661
+ data .get (arr );
662
+ return arr ;
663
+ }
664
+
665
+ @ Override
666
+ public String toString () {
667
+ return String .format ("Tensor(%s, dtype=%d)" , Arrays .toString (shape ), this .myDtype );
668
+ }
669
+ }
670
+
633
671
// region checks
634
672
private static void checkArgument (boolean expression , String errorMessage , Object ... args ) {
635
673
if (!expression ) {
@@ -674,6 +712,10 @@ private static Tensor nativeNewTensor(
674
712
tensor = new Tensor_uint8 (data , shape );
675
713
} else if (DType .INT8 .jniCode == dtype ) {
676
714
tensor = new Tensor_int8 (data , shape );
715
+ } else if (DType .HALF .jniCode == dtype ) {
716
+ tensor = new Tensor_raw_data_16b (data , shape , DType .HALF );
717
+ } else if (DType .HALF .jniCode == dtype ) {
718
+ tensor = new Tensor_raw_data_16b (data , shape , DType .BFLOAT16 );
677
719
} else {
678
720
throw new IllegalArgumentException ("Unknown Tensor dtype" );
679
721
}
0 commit comments