diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
index ab3b77ff1fb..e0122e3979e 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java
@@ -8,7 +8,6 @@
package org.pytorch.executorch;
-import com.facebook.jni.annotations.DoNotStrip;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Locale;
@@ -33,7 +32,6 @@
*
Warning: These APIs are experimental and subject to change without notice
*/
@Experimental
-@DoNotStrip
public class EValue {
private static final int TYPE_CODE_NONE = 0;
@@ -47,52 +45,50 @@ public class EValue {
"None", "Tensor", "String", "Double", "Int", "Bool",
};
- @DoNotStrip private final int mTypeCode;
- @DoNotStrip private Object mData;
+ final int mTypeCode;
+ Object mData;
- @DoNotStrip
private EValue(int typeCode) {
this.mTypeCode = typeCode;
}
- @DoNotStrip
public boolean isNone() {
return TYPE_CODE_NONE == this.mTypeCode;
}
- @DoNotStrip
+
public boolean isTensor() {
return TYPE_CODE_TENSOR == this.mTypeCode;
}
- @DoNotStrip
+
public boolean isBool() {
return TYPE_CODE_BOOL == this.mTypeCode;
}
- @DoNotStrip
+
public boolean isInt() {
return TYPE_CODE_INT == this.mTypeCode;
}
- @DoNotStrip
+
public boolean isDouble() {
return TYPE_CODE_DOUBLE == this.mTypeCode;
}
- @DoNotStrip
+
public boolean isString() {
return TYPE_CODE_STRING == this.mTypeCode;
}
/** Creates a new {@code EValue} of type {@code Optional} that contains no value. */
- @DoNotStrip
+
public static EValue optionalNone() {
return new EValue(TYPE_CODE_NONE);
}
/** Creates a new {@code EValue} of type {@code Tensor}. */
- @DoNotStrip
+
public static EValue from(Tensor tensor) {
final EValue iv = new EValue(TYPE_CODE_TENSOR);
iv.mData = tensor;
@@ -100,7 +96,7 @@ public static EValue from(Tensor tensor) {
}
/** Creates a new {@code EValue} of type {@code bool}. */
- @DoNotStrip
+
public static EValue from(boolean value) {
final EValue iv = new EValue(TYPE_CODE_BOOL);
iv.mData = value;
@@ -108,7 +104,7 @@ public static EValue from(boolean value) {
}
/** Creates a new {@code EValue} of type {@code int}. */
- @DoNotStrip
+
public static EValue from(long value) {
final EValue iv = new EValue(TYPE_CODE_INT);
iv.mData = value;
@@ -116,7 +112,7 @@ public static EValue from(long value) {
}
/** Creates a new {@code EValue} of type {@code double}. */
- @DoNotStrip
+
public static EValue from(double value) {
final EValue iv = new EValue(TYPE_CODE_DOUBLE);
iv.mData = value;
@@ -124,38 +120,38 @@ public static EValue from(double value) {
}
/** Creates a new {@code EValue} of type {@code str}. */
- @DoNotStrip
+
public static EValue from(String value) {
final EValue iv = new EValue(TYPE_CODE_STRING);
iv.mData = value;
return iv;
}
- @DoNotStrip
+
public Tensor toTensor() {
preconditionType(TYPE_CODE_TENSOR, mTypeCode);
return (Tensor) mData;
}
- @DoNotStrip
+
public boolean toBool() {
preconditionType(TYPE_CODE_BOOL, mTypeCode);
return (boolean) mData;
}
- @DoNotStrip
+
public long toInt() {
preconditionType(TYPE_CODE_INT, mTypeCode);
return (long) mData;
}
- @DoNotStrip
+
public double toDouble() {
preconditionType(TYPE_CODE_DOUBLE, mTypeCode);
return (double) mData;
}
- @DoNotStrip
+
public String toStr() {
preconditionType(TYPE_CODE_STRING, mTypeCode);
return (String) mData;
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
index 8e2f259ef3a..dfa9f77b6dd 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java
@@ -8,7 +8,6 @@
package org.pytorch.executorch;
-import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
@@ -33,10 +32,16 @@ public static ExecuTorchRuntime getRuntime() {
}
/** Get all registered ops. */
- @DoNotStrip
- public static native String[] getRegisteredOps();
+ public static String[] getRegisteredOps() {
+ return nativeGetRegisteredOps();
+ }
+
+ private static native String[] nativeGetRegisteredOps();
/** Get all registered backends. */
- @DoNotStrip
- public static native String[] getRegisteredBackends();
+ public static String[] getRegisteredBackends() {
+ return nativeGetRegisteredBackends();
+ }
+
+ private static native String[] nativeGetRegisteredBackends();
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
index 6da76bf4b74..481165f4e21 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java
@@ -9,8 +9,6 @@
package org.pytorch.executorch;
import android.util.Log;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
@@ -48,18 +46,18 @@ public class Module {
/** Load mode for the module. Use memory locking and ignore errors. */
public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3;
- private final HybridData mHybridData;
+ private long mNativeHandle;
private final Map mMethodMetadata;
- @DoNotStrip
- private static native HybridData initHybrid(
- String moduleAbsolutePath, int loadMode, int initHybrid);
+ private static native long nativeCreate(String moduleAbsolutePath, int loadMode, int numThreads);
+
+ private static native void nativeDestroy(long nativeHandle);
private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime();
- mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);
+ mNativeHandle = nativeCreate(moduleAbsolutePath, loadMode, numThreads);
mMethodMetadata = populateMethodMeta();
}
@@ -75,7 +73,7 @@ Map populateMethodMeta() {
return metadata;
}
- /** Lock protecting the non-thread safe methods in mHybridData. */
+ /** Lock protecting the non-thread safe methods in native handle. */
private Lock mLock = new ReentrantLock();
/**
@@ -138,18 +136,18 @@ public EValue[] forward(EValue... inputs) {
public EValue[] execute(String methodName, EValue... inputs) {
try {
mLock.lock();
- if (!mHybridData.isValid()) {
+ if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return new EValue[0];
}
- return executeNative(methodName, inputs);
+ return nativeExecute(mNativeHandle, methodName, inputs);
} finally {
mLock.unlock();
}
}
- @DoNotStrip
- private native EValue[] executeNative(String methodName, EValue... inputs);
+ private static native EValue[] nativeExecute(
+ long nativeHandle, String methodName, EValue... inputs);
/**
* Load a method on this module. This might help with the first time inference performance,
@@ -163,18 +161,17 @@ public EValue[] execute(String methodName, EValue... inputs) {
public int loadMethod(String methodName) {
try {
mLock.lock();
- if (!mHybridData.isValid()) {
+ if (mNativeHandle == 0) {
Log.e("ExecuTorch", "Attempt to use a destroyed module");
return 0x2; // InvalidState
}
- return loadMethodNative(methodName);
+ return nativeLoadMethod(mNativeHandle, methodName);
} finally {
mLock.unlock();
}
}
- @DoNotStrip
- private native int loadMethodNative(String methodName);
+ private static native int nativeLoadMethod(long nativeHandle, String methodName);
/**
* Returns the names of the backends in a certain method.
@@ -182,16 +179,22 @@ public int loadMethod(String methodName) {
* @param methodName method name to query
* @return an array of backend name
*/
- @DoNotStrip
- private native String[] getUsedBackends(String methodName);
+ public String[] getUsedBackends(String methodName) {
+ return nativeGetUsedBackends(mNativeHandle, methodName);
+ }
+
+ private static native String[] nativeGetUsedBackends(long nativeHandle, String methodName);
/**
* Returns the names of methods.
*
* @return name of methods in this Module
*/
- @DoNotStrip
- public native String[] getMethods();
+ public String[] getMethods() {
+ return nativeGetMethods(mNativeHandle);
+ }
+
+ private static native String[] nativeGetMethods(long nativeHandle);
/**
* Get the corresponding @MethodMetadata for a method
@@ -211,20 +214,18 @@ public MethodMetadata getMethodMetadata(String name) {
return methodMetadata;
}
- @DoNotStrip
- private static native String[] readLogBufferStaticNative();
+ private static native String[] nativeReadLogBufferStatic();
public static String[] readLogBufferStatic() {
- return readLogBufferStaticNative();
+ return nativeReadLogBufferStatic();
}
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
- return readLogBufferNative();
+ return nativeReadLogBuffer(mNativeHandle);
}
- @DoNotStrip
- private native String[] readLogBufferNative();
+ private static native String[] nativeReadLogBuffer(long nativeHandle);
/**
* Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump.
@@ -234,19 +235,25 @@ public String[] readLogBuffer() {
* @return true if the etdump was successfully written, false otherwise.
*/
@Experimental
- @DoNotStrip
- public native boolean etdump();
+ public boolean etdump() {
+ return nativeEtdump(mNativeHandle);
+ }
+
+ private static native boolean nativeEtdump(long nativeHandle);
/**
* Explicitly destroys the native Module object. Calling this method is not required, as the
* native object will be destroyed when this object is garbage-collected. However, the timing of
* garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory
- * more quickly. See {@link com.facebook.jni.HybridData#resetNative}.
+ * more quickly.
*/
public void destroy() {
if (mLock.tryLock()) {
try {
- mHybridData.resetNative();
+ if (mNativeHandle != 0) {
+ nativeDestroy(mNativeHandle);
+ mNativeHandle = 0;
+ }
} finally {
mLock.unlock();
}
@@ -257,4 +264,13 @@ public void destroy() {
+ " released.");
}
}
+
+ @Override
+ protected void finalize() throws Throwable {
+ if (mNativeHandle != 0) {
+ nativeDestroy(mNativeHandle);
+ mNativeHandle = 0;
+ }
+ super.finalize();
+ }
}
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
index e8c0a918b13..a103e3691c2 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java
@@ -9,8 +9,6 @@
package org.pytorch.executorch;
import android.util.Log;
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -53,7 +51,7 @@ public abstract class Tensor {
private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT =
"Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)";
- @DoNotStrip final long[] shape;
+ final long[] shape;
private static final int BYTE_SIZE_BYTES = 1;
private static final int INT_SIZE_BYTES = 4;
@@ -468,7 +466,8 @@ public static Tensor zeros(long[] shape, DType dtype) {
}
}
- @DoNotStrip private HybridData mHybridData;
+ // Native handle for tensor data (unused in pure JNI but kept for API compatibility)
+ private long mNativeHandle;
private Tensor(long[] shape) {
checkShape(shape);
@@ -501,7 +500,6 @@ public long[] shape() {
public abstract DType dtype();
// Called from native
- @DoNotStrip
int dtypeJniCode() {
return dtype().jniCode;
}
@@ -572,7 +570,6 @@ public double[] getDataAsDoubleArray() {
"Tensor of type " + getClass().getSimpleName() + " cannot return data as double array.");
}
- @DoNotStrip
Buffer getRawDataBuffer() {
throw new IllegalStateException(
"Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer.");
@@ -889,9 +886,8 @@ private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[
// endregion checks
// Called from native
- @DoNotStrip
private static Tensor nativeNewTensor(
- ByteBuffer data, long[] shape, int dtype, HybridData hybridData) {
+ ByteBuffer data, long[] shape, int dtype, long nativeHandle) {
Tensor tensor = null;
if (DType.FLOAT.jniCode == dtype) {
@@ -911,7 +907,7 @@ private static Tensor nativeNewTensor(
} else {
tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype));
}
- tensor.mHybridData = hybridData;
+ tensor.mNativeHandle = nativeHandle;
return tensor;
}
diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp
index 6491524c7ac..37f9b271e52 100644
--- a/extension/android/jni/jni_helper.cpp
+++ b/extension/android/jni/jni_helper.cpp
@@ -10,6 +10,60 @@
namespace executorch::jni_helper {
+void throwExecutorchException(
+ JNIEnv* env,
+ uint32_t errorCode,
+ const std::string& details) {
+ if (!env) {
+ return;
+ }
+
+ // Find the exception class
+ jclass exceptionClass =
+ env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException");
+ if (exceptionClass == nullptr) {
+ // Class not found, clear the exception and return
+ env->ExceptionClear();
+ return;
+ }
+
+ // Find the static factory method: makeExecutorchException(int, String)
+ jmethodID makeExceptionMethod = env->GetStaticMethodID(
+ exceptionClass,
+ "makeExecutorchException",
+ "(ILjava/lang/String;)Ljava/lang/RuntimeException;");
+ if (makeExceptionMethod == nullptr) {
+ env->ExceptionClear();
+ env->DeleteLocalRef(exceptionClass);
+ return;
+ }
+
+ // Create the details string
+ jstring jDetails = env->NewStringUTF(details.c_str());
+ if (jDetails == nullptr) {
+ env->ExceptionClear();
+ env->DeleteLocalRef(exceptionClass);
+ return;
+ }
+
+ // Call the factory method to create the exception object
+ jobject exception = env->CallStaticObjectMethod(
+ exceptionClass,
+ makeExceptionMethod,
+ static_cast(errorCode),
+ jDetails);
+
+ env->DeleteLocalRef(jDetails);
+
+ if (exception != nullptr) {
+ env->Throw(static_cast(exception));
+ env->DeleteLocalRef(exception);
+ }
+
+ env->DeleteLocalRef(exceptionClass);
+}
+
+#if EXECUTORCH_HAS_FBJNI
void throwExecutorchException(uint32_t errorCode, const std::string& details) {
// Get the current JNI environment
auto env = facebook::jni::Environment::current();
@@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) {
auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
facebook::jni::throwNewJavaException(exception.get());
}
+#endif
} // namespace executorch::jni_helper
diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h
index 898c1619d9c..683a3cfe447 100644
--- a/extension/android/jni/jni_helper.h
+++ b/extension/android/jni/jni_helper.h
@@ -8,9 +8,16 @@
#pragma once
-#include
+#include
#include
+#if __has_include()
+#include
+#define EXECUTORCH_HAS_FBJNI 1
+#else
+#define EXECUTORCH_HAS_FBJNI 0
+#endif
+
namespace executorch::jni_helper {
/**
@@ -18,6 +25,25 @@ namespace executorch::jni_helper {
* code and details. Uses the Java factory method
* ExecutorchRuntimeException.makeExecutorchException(int, String).
*
+ * This version takes JNIEnv* directly and works with pure JNI.
+ *
+ * @param env The JNI environment.
+ * @param errorCode The error code from the C++ Executorch runtime.
+ * @param details Additional details to include in the exception message.
+ */
+void throwExecutorchException(
+ JNIEnv* env,
+ uint32_t errorCode,
+ const std::string& details);
+
+#if EXECUTORCH_HAS_FBJNI
+/**
+ * Throws a Java ExecutorchRuntimeException corresponding to the given error
+ * code and details. Uses the Java factory method
+ * ExecutorchRuntimeException.makeExecutorchException(int, String).
+ *
+ * This version uses fbjni to get the current JNI environment.
+ *
* @param errorCode The error code from the C++ Executorch runtime.
* @param details Additional details to include in the exception message.
*/
@@ -29,5 +55,6 @@ struct JExecutorchRuntimeException
static constexpr auto kJavaDescriptor =
"Lorg/pytorch/executorch/ExecutorchRuntimeException;";
};
+#endif
} // namespace executorch::jni_helper
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index 1f8457e00c5..93c6e111a4d 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -6,6 +6,8 @@
* LICENSE file in the root directory of this source tree.
*/
+#include
+
#include
#include
@@ -39,223 +41,120 @@
#include
#endif
-#include
-#include
-
using namespace executorch::extension;
using namespace torch::executor;
-namespace executorch::extension {
-class TensorHybrid : public facebook::jni::HybridClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/Tensor;";
-
- explicit TensorHybrid(executorch::aten::Tensor tensor) {}
-
- static facebook::jni::local_ref
- newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
- // Java wrapper currently only supports contiguous tensors.
+// Helper to convert jstring to std::string (defined outside namespace for broad access)
+static std::string jstring_to_string(JNIEnv* env, jstring jstr) {
+ if (jstr == nullptr) {
+ return "";
+ }
+ const char* chars = env->GetStringUTFChars(jstr, nullptr);
+ if (chars == nullptr) {
+ return "";
+ }
+ std::string result(chars);
+ env->ReleaseStringUTFChars(jstr, chars);
+ return result;
+}
- const auto scalarType = tensor.scalar_type();
- int jdtype = scalar_type_to_java_dtype.at(scalarType);
- if (scalar_type_to_java_dtype.count(scalarType) == 0) {
- std::stringstream ss;
- ss << "executorch::aten::Tensor scalar [java] type: " << jdtype
- << " is not supported on java side";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
+namespace {
+
+// Global JavaVM pointer for obtaining JNIEnv in callbacks
+JavaVM* g_jvm = nullptr;
+
+// EValue type codes (must match Java EValue class)
+constexpr int kTypeCodeNone = 0;
+constexpr int kTypeCodeTensor = 1;
+constexpr int kTypeCodeString = 2;
+constexpr int kTypeCodeDouble = 3;
+constexpr int kTypeCodeInt = 4;
+constexpr int kTypeCodeBool = 5;
+
+// Cached class and method IDs for performance
+struct JniCache {
+ jclass tensor_class = nullptr;
+ jclass evalue_class = nullptr;
+ jmethodID tensor_nativeNewTensor = nullptr;
+ jmethodID tensor_dtypeJniCode = nullptr;
+ jmethodID tensor_getRawDataBuffer = nullptr;
+ jfieldID tensor_shape = nullptr;
+ jmethodID evalue_from_tensor = nullptr;
+ jmethodID evalue_from_long = nullptr;
+ jmethodID evalue_from_double = nullptr;
+ jmethodID evalue_from_bool = nullptr;
+ jmethodID evalue_from_string = nullptr;
+ jmethodID evalue_toTensor = nullptr;
+ jfieldID evalue_mTypeCode = nullptr;
+ jfieldID evalue_mData = nullptr;
+
+ bool initialized = false;
+
+ void init(JNIEnv* env) {
+ if (initialized) {
+ return;
}
- const auto& tensor_shape = tensor.sizes();
- std::vector tensor_shape_vec;
- for (const auto& s : tensor_shape) {
- tensor_shape_vec.push_back(s);
- }
- facebook::jni::local_ref jTensorShape =
- facebook::jni::make_long_array(tensor_shape_vec.size());
- jTensorShape->setRegion(
- 0, tensor_shape_vec.size(), tensor_shape_vec.data());
-
- static auto cls = TensorHybrid::javaClassStatic();
- // Note: this is safe as long as the data stored in tensor is valid; the
- // data won't go out of scope as long as the Method for the inference is
- // valid and there is no other inference call. Java layer picks up this
- // value immediately so the data is valid.
- facebook::jni::local_ref jTensorBuffer =
- facebook::jni::JByteBuffer::wrapBytes(
- (uint8_t*)tensor.data_ptr(), tensor.nbytes());
- jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
-
- static const auto jMethodNewTensor =
- cls->getStaticMethod(
- facebook::jni::alias_ref,
- facebook::jni::alias_ref,
- jint,
- facebook::jni::alias_ref)>("nativeNewTensor");
- return jMethodNewTensor(
- cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor));
- }
-
- static TensorPtr newTensorFromJTensor(
- facebook::jni::alias_ref jtensor) {
- static auto cls = TensorHybrid::javaClassStatic();
- static const auto dtypeMethod = cls->getMethod("dtypeJniCode");
- jint jdtype = dtypeMethod(jtensor);
-
- static const auto shapeField = cls->getField("shape");
- auto jshape = jtensor->getFieldValue(shapeField);
-
- static auto dataBufferMethod = cls->getMethod<
- facebook::jni::local_ref()>(
- "getRawDataBuffer");
- facebook::jni::local_ref jbuffer =
- dataBufferMethod(jtensor);
-
- const auto rank = jshape->size();
-
- const auto shapeArr = jshape->getRegion(0, rank);
- std::vector shape_vec;
- shape_vec.reserve(rank);
-
- int64_t numel = 1;
- for (int i = 0; i < rank; i++) {
- shape_vec.push_back(shapeArr[i]);
+ // Cache Tensor class and methods
+ jclass local_tensor_class = env->FindClass("org/pytorch/executorch/Tensor");
+ if (local_tensor_class != nullptr) {
+ tensor_class = static_cast(env->NewGlobalRef(local_tensor_class));
+ env->DeleteLocalRef(local_tensor_class);
+
+ tensor_nativeNewTensor = env->GetStaticMethodID(
+ tensor_class,
+ "nativeNewTensor",
+ "(Ljava/nio/ByteBuffer;[JIJ)Lorg/pytorch/executorch/Tensor;");
+ tensor_dtypeJniCode = env->GetMethodID(tensor_class, "dtypeJniCode", "()I");
+ tensor_getRawDataBuffer =
+ env->GetMethodID(tensor_class, "getRawDataBuffer", "()Ljava/nio/Buffer;");
+ tensor_shape = env->GetFieldID(tensor_class, "shape", "[J");
}
- for (int i = rank - 1; i >= 0; --i) {
- numel *= shapeArr[i];
- }
- JNIEnv* jni = facebook::jni::Environment::current();
- if (java_dtype_to_scalar_type.count(jdtype) == 0) {
- std::stringstream ss;
- ss << "Unknown Tensor jdtype: [" << jdtype << "]";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
- const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get());
- if (dataCapacity < 0) {
- std::stringstream ss;
- ss << "Tensor buffer is not direct or has invalid capacity";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- const size_t elementSize = executorch::runtime::elementSize(scalar_type);
- const jlong expectedElements = static_cast(numel);
- const jlong expectedBytes =
- expectedElements * static_cast(elementSize);
- const bool matchesElements = dataCapacity == expectedElements;
- const bool matchesBytes = dataCapacity == expectedBytes;
- if (!matchesElements && !matchesBytes) {
- std::stringstream ss;
- ss << "Tensor dimensions(elements number: " << numel
- << ") inconsistent with buffer capacity " << dataCapacity
- << " (element size bytes: " << elementSize << ")";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- }
- return from_blob(
- jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type);
- }
-
- private:
- friend HybridBase;
-};
-class JEValue : public facebook::jni::JavaClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/EValue;";
-
- constexpr static int kTypeCodeTensor = 1;
- constexpr static int kTypeCodeString = 2;
- constexpr static int kTypeCodeDouble = 3;
- constexpr static int kTypeCodeInt = 4;
- constexpr static int kTypeCodeBool = 5;
-
- static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) {
- if (evalue.isTensor()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(
- facebook::jni::local_ref)>("from");
- return jMethodTensor(
- JEValue::javaClassStatic(),
- TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
- } else if (evalue.isInt()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jlong)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt());
- } else if (evalue.isDouble()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jdouble)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble());
- } else if (evalue.isBool()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(jboolean)>(
- "from");
- return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool());
- } else if (evalue.isString()) {
- static auto jMethodTensor =
- JEValue::javaClassStatic()
- ->getStaticMethod(
- facebook::jni::local_ref)>("from");
- std::string str =
- std::string(evalue.toString().begin(), evalue.toString().end());
- return jMethodTensor(
- JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
- }
- std::stringstream ss;
- ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- return {};
- }
-
- static TensorPtr JEValueToTensorImpl(
- facebook::jni::alias_ref JEValue) {
- static const auto typeCodeField =
- JEValue::javaClassStatic()->getField("mTypeCode");
- const auto typeCode = JEValue->getFieldValue(typeCodeField);
- if (JEValue::kTypeCodeTensor == typeCode) {
- static const auto jMethodGetTensor =
- JEValue::javaClassStatic()
- ->getMethod()>(
- "toTensor");
- auto jtensor = jMethodGetTensor(JEValue);
- return TensorHybrid::newTensorFromJTensor(jtensor);
+ // Cache EValue class and methods
+ jclass local_evalue_class = env->FindClass("org/pytorch/executorch/EValue");
+ if (local_evalue_class != nullptr) {
+ evalue_class = static_cast(env->NewGlobalRef(local_evalue_class));
+ env->DeleteLocalRef(local_evalue_class);
+
+ evalue_from_tensor = env->GetStaticMethodID(
+ evalue_class,
+ "from",
+ "(Lorg/pytorch/executorch/Tensor;)Lorg/pytorch/executorch/EValue;");
+ evalue_from_long =
+ env->GetStaticMethodID(evalue_class, "from", "(J)Lorg/pytorch/executorch/EValue;");
+ evalue_from_double =
+ env->GetStaticMethodID(evalue_class, "from", "(D)Lorg/pytorch/executorch/EValue;");
+ evalue_from_bool =
+ env->GetStaticMethodID(evalue_class, "from", "(Z)Lorg/pytorch/executorch/EValue;");
+ evalue_from_string = env->GetStaticMethodID(
+ evalue_class,
+ "from",
+ "(Ljava/lang/String;)Lorg/pytorch/executorch/EValue;");
+ evalue_toTensor = env->GetMethodID(
+ evalue_class, "toTensor", "()Lorg/pytorch/executorch/Tensor;");
+ evalue_mTypeCode = env->GetFieldID(evalue_class, "mTypeCode", "I");
+ evalue_mData = env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;");
}
- std::stringstream ss;
- ss << "Unknown EValue typeCode: " << typeCode;
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str().c_str());
- return {};
+
+ initialized = true;
}
};
-class ExecuTorchJni : public facebook::jni::HybridClass {
- private:
- friend HybridBase;
- std::unique_ptr module_;
+JniCache g_jni_cache;
- public:
- constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;";
+} // anonymous namespace
- static facebook::jni::local_ref initHybrid(
- facebook::jni::alias_ref,
- facebook::jni::alias_ref modelPath,
- jint loadMode,
- jint numThreads) {
- return makeCxxInstance(modelPath, loadMode, numThreads);
- }
+namespace executorch::extension {
+
+// Native module handle class - named ExecuTorchJni to match friend declaration in Module
+class ExecuTorchJni {
+ public:
+ std::unique_ptr module_;
ExecuTorchJni(
- facebook::jni::alias_ref modelPath,
+ JNIEnv* env,
+ jstring modelPath,
jint loadMode,
jint numThreads) {
Module::LoadMode load_mode = Module::LoadMode::Mmap;
@@ -273,17 +172,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass {
#else
auto etdump_gen = nullptr;
#endif
- module_ = std::make_unique(
- modelPath->toStdString(), load_mode, std::move(etdump_gen));
+ std::string path = jstring_to_string(env, modelPath);
+ module_ = std::make_unique(path, load_mode, std::move(etdump_gen));
#ifdef ET_USE_THREADPOOL
- // Default to using cores/2 threadpool threads. The long-term plan is to
- // improve performant core detection in CPUInfo, but for now we can use
- // cores/2 as a sane default.
- //
- // Based on testing, this is almost universally faster than using all
- // cores, as efficiency cores can be quite slow. In extreme cases, using
- // all cores can be 10x slower than using cores/2.
auto threadpool = executorch::extension::threadpool::get_threadpool();
if (threadpool) {
int thread_count =
@@ -295,244 +187,534 @@ class ExecuTorchJni : public facebook::jni::HybridClass {
#endif
}
- facebook::jni::local_ref> execute(
- facebook::jni::alias_ref methodName,
- facebook::jni::alias_ref<
- facebook::jni::JArrayClass::javaobject>
- jinputs) {
- return execute_method(methodName->toStdString(), jinputs);
+ // Access protected methods_ member (friend class privilege)
+ Method* get_method(const std::string& method_name) {
+ auto it = module_->methods_.find(method_name);
+ if (it != module_->methods_.end()) {
+ return it->second.method.get();
+ }
+ return nullptr;
}
+};
+
+} // namespace executorch::extension
+
+namespace {
+
+// Helper to create Java Tensor from native tensor
+jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor) {
+ g_jni_cache.init(env);
- jint load_method(facebook::jni::alias_ref methodName) {
- return static_cast(module_->load_method(methodName->toStdString()));
+ const auto scalarType = tensor.scalar_type();
+ if (scalar_type_to_java_dtype.count(scalarType) == 0) {
+ std::stringstream ss;
+ ss << "executorch::aten::Tensor scalar type is not supported on java side";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
}
+ int jdtype = scalar_type_to_java_dtype.at(scalarType);
- facebook::jni::local_ref> execute_method(
- std::string method,
- facebook::jni::alias_ref<
- facebook::jni::JArrayClass::javaobject>
- jinputs) {
- // If no inputs is given, it will run with sample inputs (ones)
- if (jinputs->size() == 0) {
- auto result = module_->load_method(method);
- if (result != Error::Ok) {
- // Format hex string
- std::stringstream ss;
- ss << "Cannot get method names [Native Error: 0x" << std::hex
- << std::uppercase << static_cast(result) << "]";
+ // Create shape array
+ const auto& tensor_shape = tensor.sizes();
+ jlongArray jTensorShape = env->NewLongArray(tensor_shape.size());
+ if (jTensorShape == nullptr) {
+ return nullptr;
+ }
+ std::vector shape_vec;
+ for (const auto& s : tensor_shape) {
+ shape_vec.push_back(s);
+ }
+ env->SetLongArrayRegion(jTensorShape, 0, shape_vec.size(), shape_vec.data());
+
+ // Create ByteBuffer wrapping tensor data
+ jobject jTensorBuffer = env->NewDirectByteBuffer(
+ const_cast(tensor.const_data_ptr()), tensor.nbytes());
+ if (jTensorBuffer == nullptr) {
+ env->DeleteLocalRef(jTensorShape);
+ return nullptr;
+ }
- jni_helper::throwExecutorchException(
- static_cast(result), ss.str());
- return {};
- }
- auto&& underlying_method = module_->methods_[method].method;
- auto&& buf = prepare_input_tensors(*underlying_method);
- result = underlying_method->execute();
- if (result != Error::Ok) {
- jni_helper::throwExecutorchException(
- static_cast(result),
- "Execution failed for method: " + method);
- return {};
- }
- facebook::jni::local_ref> jresult =
- facebook::jni::JArrayClass::newArray(
- underlying_method->outputs_size());
-
- for (int i = 0; i < underlying_method->outputs_size(); i++) {
- auto jevalue =
- JEValue::newJEValueFromEValue(underlying_method->get_output(i));
- jresult->setElement(i, *jevalue);
- }
- return jresult;
- }
+ // Set byte order to native order
+ jclass byteBufferClass = env->FindClass("java/nio/ByteBuffer");
+ jmethodID orderMethod =
+ env->GetMethodID(byteBufferClass, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;");
+ jclass byteOrderClass = env->FindClass("java/nio/ByteOrder");
+ jmethodID nativeOrderMethod =
+ env->GetStaticMethodID(byteOrderClass, "nativeOrder", "()Ljava/nio/ByteOrder;");
+ jobject nativeOrder = env->CallStaticObjectMethod(byteOrderClass, nativeOrderMethod);
+ env->CallObjectMethod(jTensorBuffer, orderMethod, nativeOrder);
+
+ env->DeleteLocalRef(byteBufferClass);
+ env->DeleteLocalRef(byteOrderClass);
+ env->DeleteLocalRef(nativeOrder);
+
+ // Call nativeNewTensor static method (pass 0 for nativeHandle since we don't need it)
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.tensor_class,
+ g_jni_cache.tensor_nativeNewTensor,
+ jTensorBuffer,
+ jTensorShape,
+ jdtype,
+ static_cast(0));
+
+ env->DeleteLocalRef(jTensorBuffer);
+ env->DeleteLocalRef(jTensorShape);
+
+ return result;
+}
- std::vector evalues;
- std::vector tensors;
-
- static const auto typeCodeField =
- JEValue::javaClassStatic()->getField("mTypeCode");
-
- for (int i = 0; i < jinputs->size(); i++) {
- auto jevalue = jinputs->getElement(i);
- const auto typeCode = jevalue->getFieldValue(typeCodeField);
- if (typeCode == JEValue::kTypeCodeTensor) {
- tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue));
- evalues.emplace_back(tensors.back());
- } else if (typeCode == JEValue::kTypeCodeInt) {
- int64_t value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- } else if (typeCode == JEValue::kTypeCodeDouble) {
- double value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- } else if (typeCode == JEValue::kTypeCodeBool) {
- bool value = jevalue->getFieldValue(typeCodeField);
- evalues.emplace_back(value);
- }
- }
+// Helper to create native TensorPtr from Java Tensor
+TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) {
+ g_jni_cache.init(env);
-#ifdef EXECUTORCH_ANDROID_PROFILING
- auto start = std::chrono::high_resolution_clock::now();
- auto result = module_->execute(method, evalues);
- auto end = std::chrono::high_resolution_clock::now();
- auto duration =
- std::chrono::duration_cast(end - start)
- .count();
- ET_LOG(Debug, "Execution time: %lld ms.", duration);
+ jint jdtype = env->CallIntMethod(jtensor, g_jni_cache.tensor_dtypeJniCode);
-#else
- auto result = module_->execute(method, evalues);
+ jlongArray jshape =
+ static_cast(env->GetObjectField(jtensor, g_jni_cache.tensor_shape));
-#endif
+ jobject jbuffer = env->CallObjectMethod(jtensor, g_jni_cache.tensor_getRawDataBuffer);
- if (!result.ok()) {
- jni_helper::throwExecutorchException(
- static_cast(result.error()),
- "Execution failed for method: " + method);
- return {};
- }
+ jsize rank = env->GetArrayLength(jshape);
+
+ std::vector shapeArr(rank);
+ env->GetLongArrayRegion(jshape, 0, rank, shapeArr.data());
+
+ std::vector shape_vec;
+ shape_vec.reserve(rank);
+
+ int64_t numel = 1;
+ for (int i = 0; i < rank; i++) {
+ shape_vec.push_back(shapeArr[i]);
+ }
+ for (int i = rank - 1; i >= 0; --i) {
+ numel *= shapeArr[i];
+ }
- facebook::jni::local_ref> jresult =
- facebook::jni::JArrayClass::newArray(result.get().size());
+ if (java_dtype_to_scalar_type.count(jdtype) == 0) {
+ std::stringstream ss;
+ ss << "Unknown Tensor jdtype: [" << jdtype << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype);
+ const jlong dataCapacity = env->GetDirectBufferCapacity(jbuffer);
+ if (dataCapacity < 0) {
+ std::stringstream ss;
+ ss << "Tensor buffer is not direct or has invalid capacity";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ const size_t elementSize = executorch::runtime::elementSize(scalar_type);
+ const jlong expectedElements = static_cast(numel);
+ const jlong expectedBytes = expectedElements * static_cast(elementSize);
+ const bool matchesElements = dataCapacity == expectedElements;
+ const bool matchesBytes = dataCapacity == expectedBytes;
+
+ if (!matchesElements && !matchesBytes) {
+ std::stringstream ss;
+ ss << "Tensor dimensions(elements number: " << numel
+ << ") inconsistent with buffer capacity " << dataCapacity
+ << " (element size bytes: " << elementSize << ")";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+ return nullptr;
+ }
+
+ void* data = env->GetDirectBufferAddress(jbuffer);
+ TensorPtr result = from_blob(data, shape_vec, scalar_type);
- for (int i = 0; i < result.get().size(); i++) {
- auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]);
- jresult->setElement(i, *jevalue);
+ env->DeleteLocalRef(jshape);
+ env->DeleteLocalRef(jbuffer);
+
+ return result;
+}
+
+// Helper to create Java EValue from native EValue
+jobject newJEValueFromEValue(JNIEnv* env, EValue evalue) {
+ g_jni_cache.init(env);
+
+ if (evalue.isTensor()) {
+ jobject jtensor = newJTensorFromTensor(env, evalue.toTensor());
+ if (jtensor == nullptr) {
+ return nullptr;
}
- return jresult;
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_tensor, jtensor);
+ env->DeleteLocalRef(jtensor);
+ return result;
+ } else if (evalue.isInt()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_long, evalue.toInt());
+ } else if (evalue.isDouble()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_double, evalue.toDouble());
+ } else if (evalue.isBool()) {
+ return env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class,
+ g_jni_cache.evalue_from_bool,
+ static_cast(evalue.toBool()));
+ } else if (evalue.isString()) {
+ std::string str =
+ std::string(evalue.toString().begin(), evalue.toString().end());
+ jstring jstr = env->NewStringUTF(str.c_str());
+ jobject result = env->CallStaticObjectMethod(
+ g_jni_cache.evalue_class, g_jni_cache.evalue_from_string, jstr);
+ env->DeleteLocalRef(jstr);
+ return result;
}
- facebook::jni::local_ref>
- readLogBuffer() {
- return readLogBufferUtil();
+ std::stringstream ss;
+ ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+}
+
+// Helper to get TensorPtr from Java EValue
+TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) {
+ g_jni_cache.init(env);
+
+ jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode);
+ if (typeCode == kTypeCodeTensor) {
+ jobject jtensor =
+ env->CallObjectMethod(jevalue, g_jni_cache.evalue_toTensor);
+ TensorPtr result = newTensorFromJTensor(env, jtensor);
+ env->DeleteLocalRef(jtensor);
+ return result;
+ }
+
+ std::stringstream ss;
+ ss << "Unknown EValue typeCode: " << typeCode;
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str().c_str());
+ return nullptr;
+}
+
+} // namespace
+
+extern "C" {
+
+JNIEXPORT jlong JNICALL
+Java_org_pytorch_executorch_Module_nativeCreate(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jstring modelPath,
+ jint loadMode,
+ jint numThreads) {
+ auto* native = new executorch::extension::ExecuTorchJni(env, modelPath, loadMode, numThreads);
+ return reinterpret_cast(native);
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_Module_nativeDestroy(
+ JNIEnv* /* env */,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+ if (nativeHandle != 0) {
+ auto* native = reinterpret_cast(nativeHandle);
+ delete native;
}
+}
- static facebook::jni::local_ref>
- readLogBufferStatic(facebook::jni::alias_ref) {
- return readLogBufferUtil();
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeExecute(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName,
+ jobjectArray jinputs) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
}
- static facebook::jni::local_ref>
- readLogBufferUtil() {
-#ifdef __ANDROID__
+ g_jni_cache.init(env);
+
+ std::string method = jstring_to_string(env, methodName);
+ jsize inputSize = jinputs != nullptr ? env->GetArrayLength(jinputs) : 0;
- facebook::jni::local_ref> ret;
-
- access_log_buffer([&](std::vector& buffer) {
- const auto size = buffer.size();
- ret = facebook::jni::JArrayClass::newArray(size);
- for (auto i = 0u; i < size; i++) {
- const auto& entry = buffer[i];
- // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL
- // MESSAGE".
- std::stringstream ss;
- ss << "[" << entry.timestamp << " " << entry.function << " "
- << entry.filename << ":" << entry.line << "] "
- << static_cast(entry.level) << " " << entry.message;
-
- facebook::jni::local_ref jstr_message =
- facebook::jni::make_jstring(ss.str().c_str());
- (*ret)[i] = jstr_message;
+ // If no inputs is given, it will run with sample inputs (ones)
+ if (inputSize == 0) {
+ auto result = native->module_->load_method(method);
+ if (result != Error::Ok) {
+ std::stringstream ss;
+ ss << "Cannot get method names [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast(result) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(result), ss.str());
+ return nullptr;
+ }
+ auto* underlying_method = native->get_method(method);
+ if (underlying_method == nullptr) {
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), "Method not found: " + method);
+ return nullptr;
+ }
+ auto&& buf = prepare_input_tensors(*underlying_method);
+ result = underlying_method->execute();
+ if (result != Error::Ok) {
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(result), "Execution failed for method: " + method);
+ return nullptr;
+ }
+
+ jobjectArray jresult =
+ env->NewObjectArray(underlying_method->outputs_size(), g_jni_cache.evalue_class, nullptr);
+
+ for (int i = 0; i < underlying_method->outputs_size(); i++) {
+ jobject jevalue = newJEValueFromEValue(env, underlying_method->get_output(i));
+ env->SetObjectArrayElement(jresult, i, jevalue);
+ if (jevalue != nullptr) {
+ env->DeleteLocalRef(jevalue);
}
- });
+ }
+ return jresult;
+ }
- return ret;
-#else
- return facebook::jni::JArrayClass::newArray(0);
-#endif
+ std::vector evalues;
+ std::vector tensors;
+
+ for (int i = 0; i < inputSize; i++) {
+ jobject jevalue = env->GetObjectArrayElement(jinputs, i);
+ jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode);
+
+ if (typeCode == kTypeCodeTensor) {
+ tensors.emplace_back(JEValueToTensorImpl(env, jevalue));
+ evalues.emplace_back(tensors.back());
+ } else if (typeCode == kTypeCodeInt) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass longClass = env->FindClass("java/lang/Long");
+ jmethodID longValue = env->GetMethodID(longClass, "longValue", "()J");
+ jlong value = env->CallLongMethod(mData, longValue);
+ evalues.emplace_back(static_cast(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(longClass);
+ } else if (typeCode == kTypeCodeDouble) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass doubleClass = env->FindClass("java/lang/Double");
+ jmethodID doubleValue = env->GetMethodID(doubleClass, "doubleValue", "()D");
+ jdouble value = env->CallDoubleMethod(mData, doubleValue);
+ evalues.emplace_back(static_cast(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(doubleClass);
+ } else if (typeCode == kTypeCodeBool) {
+ jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData);
+ jclass boolClass = env->FindClass("java/lang/Boolean");
+ jmethodID boolValue = env->GetMethodID(boolClass, "booleanValue", "()Z");
+ jboolean value = env->CallBooleanMethod(mData, boolValue);
+ evalues.emplace_back(static_cast(value));
+ env->DeleteLocalRef(mData);
+ env->DeleteLocalRef(boolClass);
+ }
+ env->DeleteLocalRef(jevalue);
}
- jboolean etdump() {
#ifdef EXECUTORCH_ANDROID_PROFILING
- executorch::etdump::ETDumpGen* etdumpgen =
- (executorch::etdump::ETDumpGen*)module_->event_tracer();
- auto etdump_data = etdumpgen->get_etdump_data();
-
- if (etdump_data.buf != nullptr && etdump_data.size > 0) {
- int etdump_file =
- open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644);
- if (etdump_file == -1) {
- ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
- return false;
- }
- ssize_t bytes_written =
- write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
- if (bytes_written == -1) {
- ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
- return false;
- } else {
- ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written);
- }
- close(etdump_file);
- free(etdump_data.buf);
- return true;
- } else {
- ET_LOG(Error, "No ETDump data available!");
- }
+ auto start = std::chrono::high_resolution_clock::now();
+ auto result = native->module_->execute(method, evalues);
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration =
+ std::chrono::duration_cast(end - start).count();
+ ET_LOG(Debug, "Execution time: %lld ms.", duration);
+#else
+ auto result = native->module_->execute(method, evalues);
#endif
- return false;
+
+ if (!result.ok()) {
+ executorch::jni_helper::throwExecutorchException(
+ env,
+ static_cast(result.error()),
+ "Execution failed for method: " + method);
+ return nullptr;
}
- facebook::jni::local_ref> getMethods() {
- const auto& names_result = module_->method_names();
- if (!names_result.ok()) {
- // Format hex string
- std::stringstream ss;
- ss << "Cannot get load module [Native Error: 0x" << std::hex
- << std::uppercase << static_cast(names_result.error())
- << "]";
+ jobjectArray jresult =
+ env->NewObjectArray(result.get().size(), g_jni_cache.evalue_class, nullptr);
- jni_helper::throwExecutorchException(
- static_cast(Error::InvalidArgument), ss.str());
- return {};
- }
- const auto& methods = names_result.get();
- facebook::jni::local_ref> ret =
- facebook::jni::JArrayClass::newArray(methods.size());
- int i = 0;
- for (auto s : methods) {
- facebook::jni::local_ref method_name =
- facebook::jni::make_jstring(s.c_str());
- (*ret)[i] = method_name;
- i++;
+ for (size_t i = 0; i < result.get().size(); i++) {
+ jobject jevalue = newJEValueFromEValue(env, result.get()[i]);
+ env->SetObjectArrayElement(jresult, i, jevalue);
+ if (jevalue != nullptr) {
+ env->DeleteLocalRef(jevalue);
}
- return ret;
}
+ return jresult;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_Module_nativeLoadMethod(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return -1;
+ }
+ std::string method = jstring_to_string(env, methodName);
+ return static_cast(native->module_->load_method(method));
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeGetMethods(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
+ }
+
+ const auto& names_result = native->module_->method_names();
+ if (!names_result.ok()) {
+ std::stringstream ss;
+ ss << "Cannot get load module [Native Error: 0x" << std::hex
+ << std::uppercase << static_cast(names_result.error()) << "]";
+ executorch::jni_helper::throwExecutorchException(
+ env, static_cast(Error::InvalidArgument), ss.str());
+ return nullptr;
+ }
+
+ const auto& methods = names_result.get();
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(methods.size(), stringClass, nullptr);
+
+ int i = 0;
+ for (auto s : methods) {
+ jstring method_name = env->NewStringUTF(s.c_str());
+ env->SetObjectArrayElement(ret, i, method_name);
+ env->DeleteLocalRef(method_name);
+ i++;
+ }
+ env->DeleteLocalRef(stringClass);
+ return ret;
+}
- facebook::jni::local_ref> getUsedBackends(
- facebook::jni::alias_ref methodName) {
- auto methodMeta = module_->method_meta(methodName->toStdString()).get();
- std::unordered_set backends;
- for (auto i = 0; i < methodMeta.num_backends(); i++) {
- backends.insert(methodMeta.get_backend_name(i).get());
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeGetUsedBackends(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong nativeHandle,
+ jstring methodName) {
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return nullptr;
+ }
+
+ std::string method = jstring_to_string(env, methodName);
+ auto methodMeta = native->module_->method_meta(method).get();
+ std::unordered_set backends;
+ for (auto i = 0; i < methodMeta.num_backends(); i++) {
+ backends.insert(methodMeta.get_backend_name(i).get());
+ }
+
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(backends.size(), stringClass, nullptr);
+
+ int i = 0;
+ for (auto s : backends) {
+ jstring backend_name = env->NewStringUTF(s.c_str());
+ env->SetObjectArrayElement(ret, i, backend_name);
+ env->DeleteLocalRef(backend_name);
+ i++;
+ }
+ env->DeleteLocalRef(stringClass);
+ return ret;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeReadLogBuffer(
+ JNIEnv* env,
+ jclass /* clazz */,
+ jlong /* nativeHandle */) {
+#ifdef __ANDROID__
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = nullptr;
+
+ access_log_buffer([&](std::vector& buffer) {
+ const auto size = buffer.size();
+ ret = env->NewObjectArray(size, stringClass, nullptr);
+ for (auto i = 0u; i < size; i++) {
+ const auto& entry = buffer[i];
+ std::stringstream ss;
+ ss << "[" << entry.timestamp << " " << entry.function << " "
+ << entry.filename << ":" << entry.line << "] "
+ << static_cast(entry.level) << " " << entry.message;
+ jstring jstr_message = env->NewStringUTF(ss.str().c_str());
+ env->SetObjectArrayElement(ret, i, jstr_message);
+ env->DeleteLocalRef(jstr_message);
}
+ });
+
+ env->DeleteLocalRef(stringClass);
+ return ret;
+#else
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray ret = env->NewObjectArray(0, stringClass, nullptr);
+ env->DeleteLocalRef(stringClass);
+ return ret;
+#endif
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic(
+ JNIEnv* env,
+ jclass clazz) {
+ return Java_org_pytorch_executorch_Module_nativeReadLogBuffer(env, clazz, 0);
+}
+
+JNIEXPORT jboolean JNICALL
+Java_org_pytorch_executorch_Module_nativeEtdump(
+ JNIEnv* /* env */,
+ jclass /* clazz */,
+ jlong nativeHandle) {
+#ifdef EXECUTORCH_ANDROID_PROFILING
+ auto* native = reinterpret_cast(nativeHandle);
+ if (native == nullptr) {
+ return JNI_FALSE;
+ }
- facebook::jni::local_ref> ret =
- facebook::jni::JArrayClass::newArray(backends.size());
- int i = 0;
- for (auto s : backends) {
- facebook::jni::local_ref backend_name =
- facebook::jni::make_jstring(s.c_str());
- (*ret)[i] = backend_name;
- i++;
+ executorch::etdump::ETDumpGen* etdumpgen =
+ (executorch::etdump::ETDumpGen*)native->module_->event_tracer();
+ auto etdump_data = etdumpgen->get_etdump_data();
+
+ if (etdump_data.buf != nullptr && etdump_data.size > 0) {
+ int etdump_file =
+ open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644);
+ if (etdump_file == -1) {
+ ET_LOG(Error, "Cannot create result.etdump error: %d", errno);
+ return JNI_FALSE;
+ }
+ ssize_t bytes_written =
+ write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size);
+ if (bytes_written == -1) {
+ ET_LOG(Error, "Cannot write result.etdump error: %d", errno);
+ return JNI_FALSE;
+ } else {
+ ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written);
}
- return ret;
- }
-
- static void registerNatives() {
- registerHybrid({
- makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
- makeNativeMethod("executeNative", ExecuTorchJni::execute),
- makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
- makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
- makeNativeMethod(
- "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic),
- makeNativeMethod("etdump", ExecuTorchJni::etdump),
- makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
- makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
- });
+ close(etdump_file);
+ free(etdump_data.buf);
+ return JNI_TRUE;
+ } else {
+ ET_LOG(Error, "No ETDump data available!");
}
-};
-} // namespace executorch::extension
+#endif
+ return JNI_FALSE;
+}
+
+} // extern "C"
#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llm();
@@ -540,20 +722,72 @@ extern void register_natives_for_llm();
// No op if we don't build LLM
void register_natives_for_llm() {}
#endif
-extern void register_natives_for_runtime();
#ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING
-extern void register_natives_for_training();
+extern void register_natives_for_training(JNIEnv* env);
#else
// No op if we don't build training JNI
-void register_natives_for_training() {}
+void register_natives_for_training(JNIEnv* /* env */) {}
#endif
+void register_natives_for_runtime(JNIEnv* env);
+
+void register_natives_for_module(JNIEnv* env) {
+ jclass module_class = env->FindClass("org/pytorch/executorch/Module");
+ if (module_class == nullptr) {
+ ET_LOG(Error, "Failed to find Module class");
+ env->ExceptionClear();
+ return;
+ }
+
+ // clang-format off
+ static const JNINativeMethod methods[] = {
+ {"nativeCreate", "(Ljava/lang/String;II)J",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeCreate)},
+ {"nativeDestroy", "(J)V",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeDestroy)},
+ {"nativeExecute",
+ "(JLjava/lang/String;[Lorg/pytorch/executorch/EValue;)[Lorg/pytorch/executorch/EValue;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeExecute)},
+ {"nativeLoadMethod", "(JLjava/lang/String;)I",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeLoadMethod)},
+ {"nativeGetMethods", "(J)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetMethods)},
+ {"nativeGetUsedBackends", "(JLjava/lang/String;)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetUsedBackends)},
+ {"nativeReadLogBuffer", "(J)[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBuffer)},
+ {"nativeReadLogBufferStatic", "()[Ljava/lang/String;",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic)},
+ {"nativeEtdump", "(J)Z",
+ reinterpret_cast(Java_org_pytorch_executorch_Module_nativeEtdump)},
+ };
+ // clang-format on
+
+ int num_methods = sizeof(methods) / sizeof(methods[0]);
+ int result = env->RegisterNatives(module_class, methods, num_methods);
+ if (result != JNI_OK) {
+ ET_LOG(Error, "Failed to register native methods for Module");
+ }
+
+ env->DeleteLocalRef(module_class);
+}
+
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
- return facebook::jni::initialize(vm, [] {
- executorch::extension::ExecuTorchJni::registerNatives();
- register_natives_for_llm();
- register_natives_for_runtime();
- register_natives_for_training();
- });
+ g_jvm = vm;
+ JNIEnv* env = nullptr;
+ if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) {
+ return JNI_ERR;
+ }
+
+ // Initialize the JNI cache
+ g_jni_cache.init(env);
+
+ // Register native methods
+ register_natives_for_module(env);
+ register_natives_for_llm();
+ register_natives_for_runtime(env);
+ register_natives_for_training(env);
+
+ return JNI_VERSION_1_6;
}
diff --git a/extension/android/jni/jni_layer_runtime.cpp b/extension/android/jni/jni_layer_runtime.cpp
index 890e1d0fad9..32e7866353a 100644
--- a/extension/android/jni/jni_layer_runtime.cpp
+++ b/extension/android/jni/jni_layer_runtime.cpp
@@ -6,67 +6,90 @@
* LICENSE file in the root directory of this source tree.
*/
-#include
#include
#include
#include
+#include
namespace executorch_jni {
namespace runtime = ::executorch::ET_RUNTIME_NAMESPACE;
-class AndroidRuntimeJni : public facebook::jni::JavaClass {
- public:
- constexpr static const char* kJavaDescriptor =
- "Lorg/pytorch/executorch/ExecuTorchRuntime;";
-
- static void registerNatives() {
- javaClassStatic()->registerNatives({
- makeNativeMethod(
- "getRegisteredOps", AndroidRuntimeJni::getRegisteredOps),
- makeNativeMethod(
- "getRegisteredBackends", AndroidRuntimeJni::getRegisteredBackends),
- });
- }
+} // namespace executorch_jni
- // Returns a string array of all registered ops
- static facebook::jni::local_ref>
- getRegisteredOps(facebook::jni::alias_ref) {
- auto kernels = runtime::get_registered_kernels();
- auto result = facebook::jni::JArrayClass::newArray(kernels.size());
+extern "C" {
- for (size_t i = 0; i < kernels.size(); ++i) {
- auto op = facebook::jni::make_jstring(kernels[i].name_);
- result->setElement(i, op.get());
- }
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps(
+ JNIEnv* env,
+ jclass /* clazz */) {
+ auto kernels = executorch_jni::runtime::get_registered_kernels();
+
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray result = env->NewObjectArray(kernels.size(), stringClass, nullptr);
- return result;
+ for (size_t i = 0; i < kernels.size(); ++i) {
+ jstring op = env->NewStringUTF(kernels[i].name_);
+ env->SetObjectArrayElement(result, i, op);
+ env->DeleteLocalRef(op);
}
- // Returns a string array of all registered backends
- static facebook::jni::local_ref>
- getRegisteredBackends(facebook::jni::alias_ref) {
- int num_backends = runtime::get_num_registered_backends();
- auto result = facebook::jni::JArrayClass::newArray(num_backends);
+ env->DeleteLocalRef(stringClass);
+ return result;
+}
+
+JNIEXPORT jobjectArray JNICALL
+Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends(
+ JNIEnv* env,
+ jclass /* clazz */) {
+ int num_backends = executorch_jni::runtime::get_num_registered_backends();
- for (int i = 0; i < num_backends; ++i) {
- auto name_result = runtime::get_backend_name(i);
- const char* name = "";
+ jclass stringClass = env->FindClass("java/lang/String");
+ jobjectArray result = env->NewObjectArray(num_backends, stringClass, nullptr);
- if (name_result.ok()) {
- name = *name_result;
- }
+ for (int i = 0; i < num_backends; ++i) {
+ auto name_result = executorch_jni::runtime::get_backend_name(i);
+ const char* name = "";
- auto backend_str = facebook::jni::make_jstring(name);
- result->setElement(i, backend_str.get());
+ if (name_result.ok()) {
+ name = *name_result;
}
- return result;
+ jstring backend_str = env->NewStringUTF(name);
+ env->SetObjectArrayElement(result, i, backend_str);
+ env->DeleteLocalRef(backend_str);
}
-};
-} // namespace executorch_jni
+ env->DeleteLocalRef(stringClass);
+ return result;
+}
+
+} // extern "C"
+
+void register_natives_for_runtime(JNIEnv* env) {
+ jclass runtime_class = env->FindClass("org/pytorch/executorch/ExecuTorchRuntime");
+ if (runtime_class == nullptr) {
+ ET_LOG(Error, "Failed to find ExecuTorchRuntime class");
+ env->ExceptionClear();
+ return;
+ }
+
+ // clang-format off
+ static const JNINativeMethod methods[] = {
+ {"nativeGetRegisteredOps", "()[Ljava/lang/String;",
+ reinterpret_cast(
+ Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps)},
+ {"nativeGetRegisteredBackends", "()[Ljava/lang/String;",
+ reinterpret_cast(
+ Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends)},
+ };
+ // clang-format on
+
+ int num_methods = sizeof(methods) / sizeof(methods[0]);
+ int result = env->RegisterNatives(runtime_class, methods, num_methods);
+ if (result != JNI_OK) {
+ ET_LOG(Error, "Failed to register native methods for ExecuTorchRuntime");
+ }
-void register_natives_for_runtime() {
- executorch_jni::AndroidRuntimeJni::registerNatives();
+ env->DeleteLocalRef(runtime_class);
}
diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp
index 5a5e9f24d2f..0641013a993 100644
--- a/extension/android/jni/jni_layer_training.cpp
+++ b/extension/android/jni/jni_layer_training.cpp
@@ -12,10 +12,12 @@
#include
#include
#include
+#include
#include
#include
#include
#include
+#include
#include
#include
@@ -28,19 +30,97 @@ using namespace torch::executor;
namespace executorch::extension {
-// Forward declarations from jni_layer.cpp
+// Full implementation of TensorHybrid for training module (fbjni-based)
class TensorHybrid : public facebook::jni::HybridClass {
public:
constexpr static const char* kJavaDescriptor =
"Lorg/pytorch/executorch/Tensor;";
static facebook::jni::local_ref
- newJTensorFromTensor(const executorch::aten::Tensor& tensor);
+ newJTensorFromTensor(const executorch::aten::Tensor& tensor) {
+ const auto scalarType = tensor.scalar_type();
+ if (scalar_type_to_java_dtype.count(scalarType) == 0) {
+ facebook::jni::throwNewJavaException(
+ "java/lang/IllegalArgumentException",
+ "executorch::aten::Tensor scalar type %d is not supported on java side",
+ static_cast(scalarType));
+ }
+ int jdtype = scalar_type_to_java_dtype.at(scalarType);
+
+ const auto& tensor_shape = tensor.sizes();
+ std::vector tensor_shape_vec;
+ for (const auto& s : tensor_shape) {
+ tensor_shape_vec.push_back(s);
+ }
+ facebook::jni::local_ref jTensorShape =
+ facebook::jni::make_long_array(tensor_shape_vec.size());
+ jTensorShape->setRegion(
+ 0, tensor_shape_vec.size(), tensor_shape_vec.data());
+
+ facebook::jni::local_ref jTensorBuffer =
+ facebook::jni::JByteBuffer::wrapBytes(
+ (uint8_t*)tensor.const_data_ptr(), tensor.nbytes());
+ jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());
+
+ static auto cls = TensorHybrid::javaClassStatic();
+ static const auto jMethodNewTensor =
+ cls->getStaticMethod(
+ facebook::jni::local_ref,
+ facebook::jni::local_ref,
+ jint,
+ facebook::jni::local_ref)>("nativeNewTensor");
+ return jMethodNewTensor(
+ cls, std::move(jTensorBuffer), std::move(jTensorShape), jdtype, nullptr);
+ }
static TensorPtr newTensorFromJTensor(
- facebook::jni::alias_ref jtensor);
+ facebook::jni::alias_ref jtensor) {
+ static const auto dtypeMethod =
+ TensorHybrid::javaClassStatic()->getMethod("dtypeJniCode");
+ jint jdtype = dtypeMethod(jtensor);
+
+ static auto shapeField =
+ TensorHybrid::javaClassStatic()->getField("shape");
+ auto jshape = jtensor->getFieldValue(shapeField);
+
+ static const auto dataBufferMethod =
+ TensorHybrid::javaClassStatic()
+ ->getMethod()>(
+ "getRawDataBuffer");
+ facebook::jni::local_ref jbuffer =
+ dataBufferMethod(jtensor);
+
+ const auto rank = jshape->size();
+
+ std::vector shapeArr(rank);
+ jshape->getRegion(0, rank, shapeArr.data());
+
+ std::vector sizes_vec;
+ sizes_vec.reserve(rank);
+
+ int64_t numel = 1;
+ for (int i = 0; i < rank; i++) {
+ sizes_vec.push_back(shapeArr[i]);
+ }
+ for (int i = rank - 1; i >= 0; --i) {
+ numel *= shapeArr[i];
+ }
+
+ JNIEnv* jni = facebook::jni::Environment::current();
+ void* dataPtr = jni->GetDirectBufferAddress(jbuffer.get());
+ if (java_dtype_to_scalar_type.count(jdtype) == 0) {
+ facebook::jni::throwNewJavaException(
+ "java/lang/IllegalArgumentException",
+ "Unknown Tensor jdtype: %d",
+ jdtype);
+ }
+
+ ScalarType scalarType = java_dtype_to_scalar_type.at(jdtype);
+ return from_blob(dataPtr, sizes_vec, scalarType);
+ }
};
+// Full implementation of JEValue for training module (fbjni-based)
class JEValue : public facebook::jni::JavaClass {
public:
constexpr static const char* kJavaDescriptor =
@@ -53,10 +133,69 @@ class JEValue : public facebook::jni::JavaClass {
constexpr static int kTypeCodeBool = 5;
static facebook::jni::local_ref newJEValueFromEValue(
- runtime::EValue evalue);
+ runtime::EValue evalue) {
+ if (evalue.isTensor()) {
+ static auto jMethodTensor =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(
+ facebook::jni::local_ref)>("from");
+ return jMethodTensor(
+ JEValue::javaClassStatic(),
+ TensorHybrid::newJTensorFromTensor(evalue.toTensor()));
+ } else if (evalue.isInt()) {
+ static auto jMethodInt =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jlong)>(
+ "from");
+ return jMethodInt(JEValue::javaClassStatic(), evalue.toInt());
+ } else if (evalue.isDouble()) {
+ static auto jMethodDouble =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jdouble)>(
+ "from");
+ return jMethodDouble(JEValue::javaClassStatic(), evalue.toDouble());
+ } else if (evalue.isBool()) {
+ static auto jMethodBool =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(jboolean)>(
+ "from");
+ return jMethodBool(JEValue::javaClassStatic(), evalue.toBool());
+ } else if (evalue.isString()) {
+ static auto jMethodStr =
+ JEValue::javaClassStatic()
+ ->getStaticMethod(
+ facebook::jni::local_ref)>("from");
+ std::string str =
+ std::string(evalue.toString().begin(), evalue.toString().end());
+ return jMethodStr(
+ JEValue::javaClassStatic(), facebook::jni::make_jstring(str));
+ }
+ facebook::jni::throwNewJavaException(
+ "java/lang/IllegalArgumentException",
+ "Unknown EValue type: %d",
+ static_cast(evalue.tag));
+ return nullptr;
+ }
static TensorPtr JEValueToTensorImpl(
- facebook::jni::alias_ref JEValue);
+ facebook::jni::alias_ref jevalue) {
+ static const auto typeCodeField =
+ JEValue::javaClassStatic()->getField("mTypeCode");
+ const auto typeCode = jevalue->getFieldValue(typeCodeField);
+ if (typeCode == JEValue::kTypeCodeTensor) {
+ static const auto jMethodGetTensor =
+ JEValue::javaClassStatic()
+ ->getMethod()>(
+ "toTensor");
+ auto tensor = jMethodGetTensor(jevalue);
+ return TensorHybrid::newTensorFromJTensor(tensor);
+ }
+ facebook::jni::throwNewJavaException(
+ "java/lang/IllegalArgumentException",
+ "Unknown EValue typeCode: %d",
+ typeCode);
+ return nullptr;
+ }
};
class ExecuTorchTrainingJni
@@ -345,7 +484,7 @@ class SGDHybrid : public facebook::jni::HybridClass {
} // namespace executorch::extension
// Function to register training module natives
-void register_natives_for_training() {
+void register_natives_for_training(JNIEnv* /* env */) {
executorch::extension::ExecuTorchTrainingJni::registerNatives();
executorch::extension::SGDHybrid::registerNatives();
};