Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
Expand All @@ -33,7 +32,6 @@
* <p>Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
@DoNotStrip
public class EValue {
private static final int TYPE_CODE_NONE = 0;

Expand All @@ -47,115 +45,113 @@ public class EValue {
"None", "Tensor", "String", "Double", "Int", "Bool",
};

@DoNotStrip private final int mTypeCode;
@DoNotStrip private Object mData;
final int mTypeCode;
Object mData;

@DoNotStrip
private EValue(int typeCode) {
this.mTypeCode = typeCode;
}

@DoNotStrip
public boolean isNone() {
return TYPE_CODE_NONE == this.mTypeCode;
}

@DoNotStrip

public boolean isTensor() {
return TYPE_CODE_TENSOR == this.mTypeCode;
}

@DoNotStrip

public boolean isBool() {
return TYPE_CODE_BOOL == this.mTypeCode;
}

@DoNotStrip

public boolean isInt() {
return TYPE_CODE_INT == this.mTypeCode;
}

@DoNotStrip

public boolean isDouble() {
return TYPE_CODE_DOUBLE == this.mTypeCode;
}

@DoNotStrip

public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}

/** Creates a new {@code EValue} of type {@code Optional} that contains no value. */
@DoNotStrip

public static EValue optionalNone() {
return new EValue(TYPE_CODE_NONE);
}

/** Creates a new {@code EValue} of type {@code Tensor}. */
@DoNotStrip

public static EValue from(Tensor tensor) {
final EValue iv = new EValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
return iv;
}

/** Creates a new {@code EValue} of type {@code bool}. */
@DoNotStrip

public static EValue from(boolean value) {
final EValue iv = new EValue(TYPE_CODE_BOOL);
iv.mData = value;
return iv;
}

/** Creates a new {@code EValue} of type {@code int}. */
@DoNotStrip

public static EValue from(long value) {
final EValue iv = new EValue(TYPE_CODE_INT);
iv.mData = value;
return iv;
}

/** Creates a new {@code EValue} of type {@code double}. */
@DoNotStrip

public static EValue from(double value) {
final EValue iv = new EValue(TYPE_CODE_DOUBLE);
iv.mData = value;
return iv;
}

/** Creates a new {@code EValue} of type {@code str}. */
@DoNotStrip

public static EValue from(String value) {
final EValue iv = new EValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
}

@DoNotStrip

public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}

@DoNotStrip

public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}

@DoNotStrip

public long toInt() {
preconditionType(TYPE_CODE_INT, mTypeCode);
return (long) mData;
}

@DoNotStrip

public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}

@DoNotStrip

public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

package org.pytorch.executorch;

import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;

Expand All @@ -33,10 +32,16 @@ public static ExecuTorchRuntime getRuntime() {
}

/** Get all registered ops. */
@DoNotStrip
public static native String[] getRegisteredOps();
public static String[] getRegisteredOps() {
return nativeGetRegisteredOps();
}

private static native String[] nativeGetRegisteredOps();

/** Get all registered backends. */
@DoNotStrip
public static native String[] getRegisteredBackends();
public static String[] getRegisteredBackends() {
return nativeGetRegisteredBackends();
}

private static native String[] nativeGetRegisteredBackends();
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
package org.pytorch.executorch;

import android.util.Log;
import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
Expand Down Expand Up @@ -48,18 +46,18 @@ public class Module {
/** Load mode for the module. Use memory locking and ignore errors. */
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;

private final HybridData mHybridData;
private long mNativeHandle;

private final Map<String, MethodMetadata> mMethodMetadata;

@DoNotStrip
private static native HybridData initHybrid(
String moduleAbsolutePath, int loadMode, int initHybrid);
private static native long nativeCreate(String moduleAbsolutePath, int loadMode, int numThreads);

private static native void nativeDestroy(long nativeHandle);

private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();

mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
mNativeHandle = nativeCreate(moduleAbsolutePath, loadMode, numThreads);

mMethodMetadata = populateMethodMeta();
}
Expand All @@ -75,7 +73,7 @@ Map<String, MethodMetadata> populateMethodMeta() {
return metadata;
}

/** Lock protecting the non-thread safe methods in mHybridData. */
/** Lock protecting the non-thread safe methods in native handle. */
private Lock mLock = new ReentrantLock();

/**
Expand Down Expand Up @@ -138,18 +136,18 @@ public EValue[] forward(EValue... inputs) {
public EValue[] execute(String methodName, EValue... inputs) {
try {
mLock.lock();
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
return executeNative(methodName, inputs);
return nativeExecute(mNativeHandle, methodName, inputs);
} finally {
mLock.unlock();
}
}

@DoNotStrip
private native EValue[] executeNative(String methodName, EValue... inputs);
private static native EValue[] nativeExecute(
long nativeHandle, String methodName, EValue... inputs);

/**
* Load a method on this module. This might help with the first time inference performance,
Expand All @@ -163,35 +161,40 @@ public EValue[] execute(String methodName, EValue... inputs) {
public int loadMethod(String methodName) {
try {
mLock.lock();
if (!mHybridData.isValid()) {
if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return 0x2; // InvalidState
}
return loadMethodNative(methodName);
return nativeLoadMethod(mNativeHandle, methodName);
} finally {
mLock.unlock();
}
}

@DoNotStrip
private native int loadMethodNative(String methodName);
private static native int nativeLoadMethod(long nativeHandle, String methodName);

/**
* Returns the names of the backends in a certain method.
*
* @param methodName method name to query
* @return an array of backend name
*/
@DoNotStrip
private native String[] getUsedBackends(String methodName);
public String[] getUsedBackends(String methodName) {
return nativeGetUsedBackends(mNativeHandle, methodName);
}

private static native String[] nativeGetUsedBackends(long nativeHandle, String methodName);

/**
* Returns the names of methods.
*
* @return name of methods in this Module
*/
@DoNotStrip
public native String[] getMethods();
public String[] getMethods() {
return nativeGetMethods(mNativeHandle);
}

private static native String[] nativeGetMethods(long nativeHandle);

/**
* Get the corresponding @MethodMetadata for a method
Expand All @@ -211,20 +214,18 @@ public MethodMetadata getMethodMetadata(String name) {
return methodMetadata;
}

@DoNotStrip
private static native String[] readLogBufferStaticNative();
private static native String[] nativeReadLogBufferStatic();

public static String[] readLogBufferStatic() {
return readLogBufferStaticNative();
return nativeReadLogBufferStatic();
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
return readLogBufferNative();
return nativeReadLogBuffer(mNativeHandle);
}

@DoNotStrip
private native String[] readLogBufferNative();
private static native String[] nativeReadLogBuffer(long nativeHandle);

/**
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
Expand All @@ -234,19 +235,25 @@ public String[] readLogBuffer() {
* @return true if the etdump was successfully written, false otherwise.
*/
@Experimental
@DoNotStrip
public native boolean etdump();
public boolean etdump() {
return nativeEtdump(mNativeHandle);
}

private static native boolean nativeEtdump(long nativeHandle);

/**
* Explicitly destroys the native Module object. Calling this method is not required, as the
* native object will be destroyed when this object is garbage-collected. However, the timing of
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
* more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
* more quickly.
*/
public void destroy() {
if (mLock.tryLock()) {
try {
mHybridData.resetNative();
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
} finally {
mLock.unlock();
}
Expand All @@ -257,4 +264,13 @@ public void destroy() {
+ " released.");
}
}

@Override
protected void finalize() throws Throwable {
if (mNativeHandle != 0) {
nativeDestroy(mNativeHandle);
mNativeHandle = 0;
}
super.finalize();
}
}
Loading