diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 6b32d90cda8..5893fc56658 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -630,6 +630,44 @@ public String toString() { } } + /** + * A generic holder for 16b (fp16/bfloat16) tensor. Not intended for general use. User can only + * use #Tensor.getDataAsUnsignedByteArray() to get raw bytes. Users need to parse it. + */ + static class Tensor_raw_data_16b extends Tensor { + private final ByteBuffer data; + private final DType myDtype; + + private Tensor_raw_data_16b(ByteBuffer data, long[] shape, DType dtype) { + super(shape); + this.data = data; + this.myDtype = dtype; + } + + @Override + public DType dtype() { + return myDtype; + } + + @Override + Buffer getRawDataBuffer() { + return data; + } + + @Override + public byte[] getDataAsUnsignedByteArray() { + data.rewind(); + byte[] arr = new byte[data.remaining()]; + data.get(arr); + return arr; + } + + @Override + public String toString() { + return String.format("Tensor(%s, dtype=%d)", Arrays.toString(shape), this.myDtype); + } + } + // region checks private static void checkArgument(boolean expression, String errorMessage, Object... args) { if (!expression) { @@ -674,6 +712,10 @@ private static Tensor nativeNewTensor( tensor = new Tensor_uint8(data, shape); } else if (DType.INT8.jniCode == dtype) { tensor = new Tensor_int8(data, shape); + } else if (DType.HALF.jniCode == dtype) { + tensor = new Tensor_raw_data_16b(data, shape, DType.HALF); + } else if (DType.BFLOAT16.jniCode == dtype) { + tensor = new Tensor_raw_data_16b(data, shape, DType.BFLOAT16); } else { throw new IllegalArgumentException("Unknown Tensor dtype"); }