Skip to content

Commit 7e2106c

Browse files
committed
[Android] Support 16bit data as raw data byte[]
In java, when the returned dtype is fp16 or bf16, instead of crash, use byte[] to represent these raw data, and let user parse the byte[]
1 parent 98c2c53 commit 7e2106c

File tree

1 file changed

+42
-0
lines changed
  • extension/android/executorch_android/src/main/java/org/pytorch/executorch

1 file changed

+42
-0
lines changed

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,44 @@ public String toString() {
630630
}
631631
}
632632

633+
/**
634+
* A generic holder for 16b (fp16/bfloat16) tensor. Not intended for general use. User can only
635+
* use #Tensor.getDataAsUnsignedByteArray() to get raw bytes. Users need to parse it.
636+
*/
637+
static class Tensor_raw_data_16b extends Tensor {
638+
private final ByteBuffer data;
639+
private final DType myDtype;
640+
641+
private Tensor_raw_data_16b(ByteBuffer data, long[] shape, DType dtype) {
642+
super(shape);
643+
this.data = data;
644+
this.myDtype = dtype;
645+
}
646+
647+
@Override
648+
public DType dtype() {
649+
return myDtype;
650+
}
651+
652+
@Override
653+
Buffer getRawDataBuffer() {
654+
return data;
655+
}
656+
657+
@Override
658+
public byte[] getDataAsUnsignedByteArray() {
659+
data.rewind();
660+
byte[] arr = new byte[data.remaining()];
661+
data.get(arr);
662+
return arr;
663+
}
664+
665+
@Override
666+
public String toString() {
667+
return String.format("Tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype);
668+
}
669+
}
670+
633671
// region checks
634672
private static void checkArgument(boolean expression, String errorMessage, Object... args) {
635673
if (!expression) {
@@ -674,6 +712,10 @@ private static Tensor nativeNewTensor(
674712
tensor = new Tensor_uint8(data, shape);
675713
} else if (DType.INT8.jniCode == dtype) {
676714
tensor = new Tensor_int8(data, shape);
715+
} else if (DType.HALF.jniCode == dtype) {
716+
tensor = new Tensor_raw_data_16b(data, shape, DType.HALF);
717+
} else if (DType.HALF.jniCode == dtype) {
718+
tensor = new Tensor_raw_data_16b(data, shape, DType.BFLOAT16);
677719
} else {
678720
throw new IllegalArgumentException("Unknown Tensor dtype");
679721
}

0 commit comments

Comments
 (0)