File tree Expand file tree Collapse file tree 2 files changed +36
-1
lines changed
extension/android/executorch_android/src/main/java/org/pytorch/executorch Expand file tree Collapse file tree 2 files changed +36
-1
lines changed Original file line number Diff line number Diff line change @@ -73,4 +73,13 @@ public enum DType {
7373 DType (int jniCode ) {
7474 this .jniCode = jniCode ;
7575 }
76+
77+ public static DType fromJniCode (int jniCode ) {
78+ for (DType dtype : values ()) {
79+ if (dtype .jniCode == jniCode ) {
80+ return dtype ;
81+ }
82+ }
83+ throw new IllegalArgumentException ("No DType found for jniCode " + jniCode );
84+ }
7685}
Original file line number Diff line number Diff line change 88
99package org .pytorch .executorch ;
1010
11+ import android .util .Log ;
1112import com .facebook .jni .HybridData ;
1213import com .facebook .jni .annotations .DoNotStrip ;
1314import java .nio .Buffer ;
@@ -630,6 +631,31 @@ public String toString() {
630631 }
631632 }
632633
634+ static class Tensor_unsupported extends Tensor {
635+ private final ByteBuffer data ;
636+ private final DType myDtype ;
637+
638+ private Tensor_unsupported (ByteBuffer data , long [] shape , DType dtype ) {
639+ super (shape );
640+ this .data = data ;
641+ this .myDtype = dtype ;
642+ Log .e (
643+ "ExecuTorch" ,
644+ toString () + " in Java. Please consider re-export the model with proper return type" );
645+ }
646+
647+ @ Override
648+ public DType dtype () {
649+ return myDtype ;
650+ }
651+
652+ @ Override
653+ public String toString () {
654+ return String .format (
655+ "Unsupported tensor(%s, dtype=%d)" , Arrays .toString (shape ), this .myDtype );
656+ }
657+ }
658+
633659 // region checks
634660 private static void checkArgument (boolean expression , String errorMessage , Object ... args ) {
635661 if (!expression ) {
@@ -675,7 +701,7 @@ private static Tensor nativeNewTensor(
675701 } else if (DType .INT8 .jniCode == dtype ) {
676702 tensor = new Tensor_int8 (data , shape );
677703 } else {
678- throw new IllegalArgumentException ( "Unknown Tensor dtype" );
704+ tensor = new Tensor_unsupported ( data , shape , DType . fromJniCode ( dtype ) );
679705 }
680706 tensor .mHybridData = hybridData ;
681707 return tensor ;
You can’t perform that action at this time.
0 commit comments