diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java index ab3b77ff1fb..e0122e3979e 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/EValue.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import com.facebook.jni.annotations.DoNotStrip; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Locale; @@ -33,7 +32,6 @@ *

Warning: These APIs are experimental and subject to change without notice */ @Experimental -@DoNotStrip public class EValue { private static final int TYPE_CODE_NONE = 0; @@ -47,52 +45,50 @@ public class EValue { "None", "Tensor", "String", "Double", "Int", "Bool", }; - @DoNotStrip private final int mTypeCode; - @DoNotStrip private Object mData; + final int mTypeCode; + Object mData; - @DoNotStrip private EValue(int typeCode) { this.mTypeCode = typeCode; } - @DoNotStrip public boolean isNone() { return TYPE_CODE_NONE == this.mTypeCode; } - @DoNotStrip + public boolean isTensor() { return TYPE_CODE_TENSOR == this.mTypeCode; } - @DoNotStrip + public boolean isBool() { return TYPE_CODE_BOOL == this.mTypeCode; } - @DoNotStrip + public boolean isInt() { return TYPE_CODE_INT == this.mTypeCode; } - @DoNotStrip + public boolean isDouble() { return TYPE_CODE_DOUBLE == this.mTypeCode; } - @DoNotStrip + public boolean isString() { return TYPE_CODE_STRING == this.mTypeCode; } /** Creates a new {@code EValue} of type {@code Optional} that contains no value. */ - @DoNotStrip + public static EValue optionalNone() { return new EValue(TYPE_CODE_NONE); } /** Creates a new {@code EValue} of type {@code Tensor}. */ - @DoNotStrip + public static EValue from(Tensor tensor) { final EValue iv = new EValue(TYPE_CODE_TENSOR); iv.mData = tensor; @@ -100,7 +96,7 @@ public static EValue from(Tensor tensor) { } /** Creates a new {@code EValue} of type {@code bool}. */ - @DoNotStrip + public static EValue from(boolean value) { final EValue iv = new EValue(TYPE_CODE_BOOL); iv.mData = value; @@ -108,7 +104,7 @@ public static EValue from(boolean value) { } /** Creates a new {@code EValue} of type {@code int}. */ - @DoNotStrip + public static EValue from(long value) { final EValue iv = new EValue(TYPE_CODE_INT); iv.mData = value; @@ -116,7 +112,7 @@ public static EValue from(long value) { } /** Creates a new {@code EValue} of type {@code double}. */ - @DoNotStrip + public static EValue from(double value) { final EValue iv = new EValue(TYPE_CODE_DOUBLE); iv.mData = value; @@ -124,38 +120,38 @@ public static EValue from(double value) { } /** Creates a new {@code EValue} of type {@code str}. */ - @DoNotStrip + public static EValue from(String value) { final EValue iv = new EValue(TYPE_CODE_STRING); iv.mData = value; return iv; } - @DoNotStrip + public Tensor toTensor() { preconditionType(TYPE_CODE_TENSOR, mTypeCode); return (Tensor) mData; } - @DoNotStrip + public boolean toBool() { preconditionType(TYPE_CODE_BOOL, mTypeCode); return (boolean) mData; } - @DoNotStrip + public long toInt() { preconditionType(TYPE_CODE_INT, mTypeCode); return (long) mData; } - @DoNotStrip + public double toDouble() { preconditionType(TYPE_CODE_DOUBLE, mTypeCode); return (double) mData; } - @DoNotStrip + public String toStr() { preconditionType(TYPE_CODE_STRING, mTypeCode); return (String) mData; diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java index 8e2f259ef3a..dfa9f77b6dd 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java @@ -8,7 +8,6 @@ package org.pytorch.executorch; -import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; @@ -33,10 +32,16 @@ public static ExecuTorchRuntime getRuntime() { } /** Get all registered ops. */ - @DoNotStrip - public static native String[] getRegisteredOps(); + public static String[] getRegisteredOps() { + return nativeGetRegisteredOps(); + } + + private static native String[] nativeGetRegisteredOps(); /** Get all registered backends. */ - @DoNotStrip - public static native String[] getRegisteredBackends(); + public static String[] getRegisteredBackends() { + return nativeGetRegisteredBackends(); + } + + private static native String[] nativeGetRegisteredBackends(); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 6da76bf4b74..481165f4e21 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -9,8 +9,6 @@ package org.pytorch.executorch; import android.util.Log; -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.io.File; @@ -48,18 +46,18 @@ public class Module { /** Load mode for the module. Use memory locking and ignore errors. */ public static final int LOAD_MODE_MMAP_USE_MLOCK_IGNORE_ERRORS = 3; - private final HybridData mHybridData; + private long mNativeHandle; private final Map mMethodMetadata; - @DoNotStrip - private static native HybridData initHybrid( - String moduleAbsolutePath, int loadMode, int initHybrid); + private static native long nativeCreate(String moduleAbsolutePath, int loadMode, int numThreads); + + private static native void nativeDestroy(long nativeHandle); private Module(String moduleAbsolutePath, int loadMode, int numThreads) { ExecuTorchRuntime runtime = ExecuTorchRuntime.getRuntime(); - mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); + mNativeHandle = nativeCreate(moduleAbsolutePath, loadMode, numThreads); mMethodMetadata = populateMethodMeta(); } @@ -75,7 +73,7 @@ Map populateMethodMeta() { return metadata; } - /** Lock protecting the non-thread safe methods in mHybridData. */ + /** Lock protecting the non-thread safe methods in native handle. */ private Lock mLock = new ReentrantLock(); /** @@ -138,18 +136,18 @@ public EValue[] forward(EValue... inputs) { public EValue[] execute(String methodName, EValue... inputs) { try { mLock.lock(); - if (!mHybridData.isValid()) { + if (mNativeHandle == 0) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return new EValue[0]; } - return executeNative(methodName, inputs); + return nativeExecute(mNativeHandle, methodName, inputs); } finally { mLock.unlock(); } } - @DoNotStrip - private native EValue[] executeNative(String methodName, EValue... inputs); + private static native EValue[] nativeExecute( + long nativeHandle, String methodName, EValue... inputs); /** * Load a method on this module. This might help with the first time inference performance, @@ -163,18 +161,17 @@ public EValue[] execute(String methodName, EValue... inputs) { public int loadMethod(String methodName) { try { mLock.lock(); - if (!mHybridData.isValid()) { + if (mNativeHandle == 0) { Log.e("ExecuTorch", "Attempt to use a destroyed module"); return 0x2; // InvalidState } - return loadMethodNative(methodName); + return nativeLoadMethod(mNativeHandle, methodName); } finally { mLock.unlock(); } } - @DoNotStrip - private native int loadMethodNative(String methodName); + private static native int nativeLoadMethod(long nativeHandle, String methodName); /** * Returns the names of the backends in a certain method. @@ -182,16 +179,22 @@ public int loadMethod(String methodName) { * @param methodName method name to query * @return an array of backend name */ - @DoNotStrip - private native String[] getUsedBackends(String methodName); + public String[] getUsedBackends(String methodName) { + return nativeGetUsedBackends(mNativeHandle, methodName); + } + + private static native String[] nativeGetUsedBackends(long nativeHandle, String methodName); /** * Returns the names of methods. * * @return name of methods in this Module */ - @DoNotStrip - public native String[] getMethods(); + public String[] getMethods() { + return nativeGetMethods(mNativeHandle); + } + + private static native String[] nativeGetMethods(long nativeHandle); /** * Get the corresponding @MethodMetadata for a method @@ -211,20 +214,18 @@ public MethodMetadata getMethodMetadata(String name) { return methodMetadata; } - @DoNotStrip - private static native String[] readLogBufferStaticNative(); + private static native String[] nativeReadLogBufferStatic(); public static String[] readLogBufferStatic() { - return readLogBufferStaticNative(); + return nativeReadLogBufferStatic(); } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { - return readLogBufferNative(); + return nativeReadLogBuffer(mNativeHandle); } - @DoNotStrip - private native String[] readLogBufferNative(); + private static native String[] nativeReadLogBuffer(long nativeHandle); /** * Dump the ExecuTorch ETRecord file to /data/local/tmp/result.etdump. @@ -234,19 +235,25 @@ public String[] readLogBuffer() { * @return true if the etdump was successfully written, false otherwise. */ @Experimental - @DoNotStrip - public native boolean etdump(); + public boolean etdump() { + return nativeEtdump(mNativeHandle); + } + + private static native boolean nativeEtdump(long nativeHandle); /** * Explicitly destroys the native Module object. Calling this method is not required, as the * native object will be destroyed when this object is garbage-collected. However, the timing of * garbage collection is not guaranteed, so proactively calling {@code destroy} can free memory - * more quickly. See {@link com.facebook.jni.HybridData#resetNative}. + * more quickly. */ public void destroy() { if (mLock.tryLock()) { try { - mHybridData.resetNative(); + if (mNativeHandle != 0) { + nativeDestroy(mNativeHandle); + mNativeHandle = 0; + } } finally { mLock.unlock(); } @@ -257,4 +264,13 @@ public void destroy() { + " released."); } } + + @Override + protected void finalize() throws Throwable { + if (mNativeHandle != 0) { + nativeDestroy(mNativeHandle); + mNativeHandle = 0; + } + super.finalize(); + } } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index e8c0a918b13..a103e3691c2 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -9,8 +9,6 @@ package org.pytorch.executorch; import android.util.Log; -import com.facebook.jni.HybridData; -import com.facebook.jni.annotations.DoNotStrip; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -53,7 +51,7 @@ public abstract class Tensor { private static final String ERROR_MSG_DATA_BUFFER_MUST_BE_DIRECT = "Data buffer must be direct (java.nio.ByteBuffer#allocateDirect)"; - @DoNotStrip final long[] shape; + final long[] shape; private static final int BYTE_SIZE_BYTES = 1; private static final int INT_SIZE_BYTES = 4; @@ -468,7 +466,8 @@ public static Tensor zeros(long[] shape, DType dtype) { } } - @DoNotStrip private HybridData mHybridData; + // Native handle for tensor data (unused in pure JNI but kept for API compatibility) + private long mNativeHandle; private Tensor(long[] shape) { checkShape(shape); @@ -501,7 +500,6 @@ public long[] shape() { public abstract DType dtype(); // Called from native - @DoNotStrip int dtypeJniCode() { return dtype().jniCode; } @@ -572,7 +570,6 @@ public double[] getDataAsDoubleArray() { "Tensor of type " + getClass().getSimpleName() + " cannot return data as double array."); } - @DoNotStrip Buffer getRawDataBuffer() { throw new IllegalStateException( "Tensor of type " + getClass().getSimpleName() + " cannot " + "return raw data buffer."); @@ -889,9 +886,8 @@ private static void checkShapeAndDataCapacityConsistency(int dataCapacity, long[ // endregion checks // Called from native - @DoNotStrip private static Tensor nativeNewTensor( - ByteBuffer data, long[] shape, int dtype, HybridData hybridData) { + ByteBuffer data, long[] shape, int dtype, long nativeHandle) { Tensor tensor = null; if (DType.FLOAT.jniCode == dtype) { @@ -911,7 +907,7 @@ private static Tensor nativeNewTensor( } else { tensor = new Tensor_unsupported(data, shape, DType.fromJniCode(dtype)); } - tensor.mHybridData = hybridData; + tensor.mNativeHandle = nativeHandle; return tensor; } diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp index 6491524c7ac..37f9b271e52 100644 --- a/extension/android/jni/jni_helper.cpp +++ b/extension/android/jni/jni_helper.cpp @@ -10,6 +10,60 @@ namespace executorch::jni_helper { +void throwExecutorchException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details) { + if (!env) { + return; + } + + // Find the exception class + jclass exceptionClass = + env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException"); + if (exceptionClass == nullptr) { + // Class not found, clear the exception and return + env->ExceptionClear(); + return; + } + + // Find the static factory method: makeExecutorchException(int, String) + jmethodID makeExceptionMethod = env->GetStaticMethodID( + exceptionClass, + "makeExecutorchException", + "(ILjava/lang/String;)Ljava/lang/RuntimeException;"); + if (makeExceptionMethod == nullptr) { + env->ExceptionClear(); + env->DeleteLocalRef(exceptionClass); + return; + } + + // Create the details string + jstring jDetails = env->NewStringUTF(details.c_str()); + if (jDetails == nullptr) { + env->ExceptionClear(); + env->DeleteLocalRef(exceptionClass); + return; + } + + // Call the factory method to create the exception object + jobject exception = env->CallStaticObjectMethod( + exceptionClass, + makeExceptionMethod, + static_cast(errorCode), + jDetails); + + env->DeleteLocalRef(jDetails); + + if (exception != nullptr) { + env->Throw(static_cast(exception)); + env->DeleteLocalRef(exception); + } + + env->DeleteLocalRef(exceptionClass); +} + +#if EXECUTORCH_HAS_FBJNI void throwExecutorchException(uint32_t errorCode, const std::string& details) { // Get the current JNI environment auto env = facebook::jni::Environment::current(); @@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) { auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails); facebook::jni::throwNewJavaException(exception.get()); } +#endif } // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h index 898c1619d9c..683a3cfe447 100644 --- a/extension/android/jni/jni_helper.h +++ b/extension/android/jni/jni_helper.h @@ -8,9 +8,16 @@ #pragma once -#include +#include #include +#if __has_include() +#include +#define EXECUTORCH_HAS_FBJNI 1 +#else +#define EXECUTORCH_HAS_FBJNI 0 +#endif + namespace executorch::jni_helper { /** @@ -18,6 +25,25 @@ namespace executorch::jni_helper { * code and details. Uses the Java factory method * ExecutorchRuntimeException.makeExecutorchException(int, String). * + * This version takes JNIEnv* directly and works with pure JNI. + * + * @param env The JNI environment. + * @param errorCode The error code from the C++ Executorch runtime. + * @param details Additional details to include in the exception message. + */ +void throwExecutorchException( + JNIEnv* env, + uint32_t errorCode, + const std::string& details); + +#if EXECUTORCH_HAS_FBJNI +/** + * Throws a Java ExecutorchRuntimeException corresponding to the given error + * code and details. Uses the Java factory method + * ExecutorchRuntimeException.makeExecutorchException(int, String). + * + * This version uses fbjni to get the current JNI environment. + * * @param errorCode The error code from the C++ Executorch runtime. * @param details Additional details to include in the exception message. */ @@ -29,5 +55,6 @@ struct JExecutorchRuntimeException static constexpr auto kJavaDescriptor = "Lorg/pytorch/executorch/ExecutorchRuntimeException;"; }; +#endif } // namespace executorch::jni_helper diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 1f8457e00c5..93c6e111a4d 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include @@ -39,223 +41,120 @@ #include #endif -#include -#include - using namespace executorch::extension; using namespace torch::executor; -namespace executorch::extension { -class TensorHybrid : public facebook::jni::HybridClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/Tensor;"; - - explicit TensorHybrid(executorch::aten::Tensor tensor) {} - - static facebook::jni::local_ref - newJTensorFromTensor(const executorch::aten::Tensor& tensor) { - // Java wrapper currently only supports contiguous tensors. +// Helper to convert jstring to std::string (defined outside namespace for broad access) +static std::string jstring_to_string(JNIEnv* env, jstring jstr) { + if (jstr == nullptr) { + return ""; + } + const char* chars = env->GetStringUTFChars(jstr, nullptr); + if (chars == nullptr) { + return ""; + } + std::string result(chars); + env->ReleaseStringUTFChars(jstr, chars); + return result; +} - const auto scalarType = tensor.scalar_type(); - int jdtype = scalar_type_to_java_dtype.at(scalarType); - if (scalar_type_to_java_dtype.count(scalarType) == 0) { - std::stringstream ss; - ss << "executorch::aten::Tensor scalar [java] type: " << jdtype - << " is not supported on java side"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); +namespace { + +// Global JavaVM pointer for obtaining JNIEnv in callbacks +JavaVM* g_jvm = nullptr; + +// EValue type codes (must match Java EValue class) +constexpr int kTypeCodeNone = 0; +constexpr int kTypeCodeTensor = 1; +constexpr int kTypeCodeString = 2; +constexpr int kTypeCodeDouble = 3; +constexpr int kTypeCodeInt = 4; +constexpr int kTypeCodeBool = 5; + +// Cached class and method IDs for performance +struct JniCache { + jclass tensor_class = nullptr; + jclass evalue_class = nullptr; + jmethodID tensor_nativeNewTensor = nullptr; + jmethodID tensor_dtypeJniCode = nullptr; + jmethodID tensor_getRawDataBuffer = nullptr; + jfieldID tensor_shape = nullptr; + jmethodID evalue_from_tensor = nullptr; + jmethodID evalue_from_long = nullptr; + jmethodID evalue_from_double = nullptr; + jmethodID evalue_from_bool = nullptr; + jmethodID evalue_from_string = nullptr; + jmethodID evalue_toTensor = nullptr; + jfieldID evalue_mTypeCode = nullptr; + jfieldID evalue_mData = nullptr; + + bool initialized = false; + + void init(JNIEnv* env) { + if (initialized) { + return; } - const auto& tensor_shape = tensor.sizes(); - std::vector tensor_shape_vec; - for (const auto& s : tensor_shape) { - tensor_shape_vec.push_back(s); - } - facebook::jni::local_ref jTensorShape = - facebook::jni::make_long_array(tensor_shape_vec.size()); - jTensorShape->setRegion( - 0, tensor_shape_vec.size(), tensor_shape_vec.data()); - - static auto cls = TensorHybrid::javaClassStatic(); - // Note: this is safe as long as the data stored in tensor is valid; the - // data won't go out of scope as long as the Method for the inference is - // valid and there is no other inference call. Java layer picks up this - // value immediately so the data is valid. - facebook::jni::local_ref jTensorBuffer = - facebook::jni::JByteBuffer::wrapBytes( - (uint8_t*)tensor.data_ptr(), tensor.nbytes()); - jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); - - static const auto jMethodNewTensor = - cls->getStaticMethod( - facebook::jni::alias_ref, - facebook::jni::alias_ref, - jint, - facebook::jni::alias_ref)>("nativeNewTensor"); - return jMethodNewTensor( - cls, jTensorBuffer, jTensorShape, jdtype, makeCxxInstance(tensor)); - } - - static TensorPtr newTensorFromJTensor( - facebook::jni::alias_ref jtensor) { - static auto cls = TensorHybrid::javaClassStatic(); - static const auto dtypeMethod = cls->getMethod("dtypeJniCode"); - jint jdtype = dtypeMethod(jtensor); - - static const auto shapeField = cls->getField("shape"); - auto jshape = jtensor->getFieldValue(shapeField); - - static auto dataBufferMethod = cls->getMethod< - facebook::jni::local_ref()>( - "getRawDataBuffer"); - facebook::jni::local_ref jbuffer = - dataBufferMethod(jtensor); - - const auto rank = jshape->size(); - - const auto shapeArr = jshape->getRegion(0, rank); - std::vector shape_vec; - shape_vec.reserve(rank); - - int64_t numel = 1; - for (int i = 0; i < rank; i++) { - shape_vec.push_back(shapeArr[i]); + // Cache Tensor class and methods + jclass local_tensor_class = env->FindClass("org/pytorch/executorch/Tensor"); + if (local_tensor_class != nullptr) { + tensor_class = static_cast(env->NewGlobalRef(local_tensor_class)); + env->DeleteLocalRef(local_tensor_class); + + tensor_nativeNewTensor = env->GetStaticMethodID( + tensor_class, + "nativeNewTensor", + "(Ljava/nio/ByteBuffer;[JIJ)Lorg/pytorch/executorch/Tensor;"); + tensor_dtypeJniCode = env->GetMethodID(tensor_class, "dtypeJniCode", "()I"); + tensor_getRawDataBuffer = + env->GetMethodID(tensor_class, "getRawDataBuffer", "()Ljava/nio/Buffer;"); + tensor_shape = env->GetFieldID(tensor_class, "shape", "[J"); } - for (int i = rank - 1; i >= 0; --i) { - numel *= shapeArr[i]; - } - JNIEnv* jni = facebook::jni::Environment::current(); - if (java_dtype_to_scalar_type.count(jdtype) == 0) { - std::stringstream ss; - ss << "Unknown Tensor jdtype: [" << jdtype << "]"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); - const jlong dataCapacity = jni->GetDirectBufferCapacity(jbuffer.get()); - if (dataCapacity < 0) { - std::stringstream ss; - ss << "Tensor buffer is not direct or has invalid capacity"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - const size_t elementSize = executorch::runtime::elementSize(scalar_type); - const jlong expectedElements = static_cast(numel); - const jlong expectedBytes = - expectedElements * static_cast(elementSize); - const bool matchesElements = dataCapacity == expectedElements; - const bool matchesBytes = dataCapacity == expectedBytes; - if (!matchesElements && !matchesBytes) { - std::stringstream ss; - ss << "Tensor dimensions(elements number: " << numel - << ") inconsistent with buffer capacity " << dataCapacity - << " (element size bytes: " << elementSize << ")"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - } - return from_blob( - jni->GetDirectBufferAddress(jbuffer.get()), shape_vec, scalar_type); - } - - private: - friend HybridBase; -}; -class JEValue : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/EValue;"; - - constexpr static int kTypeCodeTensor = 1; - constexpr static int kTypeCodeString = 2; - constexpr static int kTypeCodeDouble = 3; - constexpr static int kTypeCodeInt = 4; - constexpr static int kTypeCodeBool = 5; - - static facebook::jni::local_ref newJEValueFromEValue(EValue evalue) { - if (evalue.isTensor()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::local_ref)>("from"); - return jMethodTensor( - JEValue::javaClassStatic(), - TensorHybrid::newJTensorFromTensor(evalue.toTensor())); - } else if (evalue.isInt()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jlong)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toInt()); - } else if (evalue.isDouble()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jdouble)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toDouble()); - } else if (evalue.isBool()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod(jboolean)>( - "from"); - return jMethodTensor(JEValue::javaClassStatic(), evalue.toBool()); - } else if (evalue.isString()) { - static auto jMethodTensor = - JEValue::javaClassStatic() - ->getStaticMethod( - facebook::jni::local_ref)>("from"); - std::string str = - std::string(evalue.toString().begin(), evalue.toString().end()); - return jMethodTensor( - JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); - } - std::stringstream ss; - ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return {}; - } - - static TensorPtr JEValueToTensorImpl( - facebook::jni::alias_ref JEValue) { - static const auto typeCodeField = - JEValue::javaClassStatic()->getField("mTypeCode"); - const auto typeCode = JEValue->getFieldValue(typeCodeField); - if (JEValue::kTypeCodeTensor == typeCode) { - static const auto jMethodGetTensor = - JEValue::javaClassStatic() - ->getMethod()>( - "toTensor"); - auto jtensor = jMethodGetTensor(JEValue); - return TensorHybrid::newTensorFromJTensor(jtensor); + // Cache EValue class and methods + jclass local_evalue_class = env->FindClass("org/pytorch/executorch/EValue"); + if (local_evalue_class != nullptr) { + evalue_class = static_cast(env->NewGlobalRef(local_evalue_class)); + env->DeleteLocalRef(local_evalue_class); + + evalue_from_tensor = env->GetStaticMethodID( + evalue_class, + "from", + "(Lorg/pytorch/executorch/Tensor;)Lorg/pytorch/executorch/EValue;"); + evalue_from_long = + env->GetStaticMethodID(evalue_class, "from", "(J)Lorg/pytorch/executorch/EValue;"); + evalue_from_double = + env->GetStaticMethodID(evalue_class, "from", "(D)Lorg/pytorch/executorch/EValue;"); + evalue_from_bool = + env->GetStaticMethodID(evalue_class, "from", "(Z)Lorg/pytorch/executorch/EValue;"); + evalue_from_string = env->GetStaticMethodID( + evalue_class, + "from", + "(Ljava/lang/String;)Lorg/pytorch/executorch/EValue;"); + evalue_toTensor = env->GetMethodID( + evalue_class, "toTensor", "()Lorg/pytorch/executorch/Tensor;"); + evalue_mTypeCode = env->GetFieldID(evalue_class, "mTypeCode", "I"); + evalue_mData = env->GetFieldID(evalue_class, "mData", "Ljava/lang/Object;"); } - std::stringstream ss; - ss << "Unknown EValue typeCode: " << typeCode; - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str().c_str()); - return {}; + + initialized = true; } }; -class ExecuTorchJni : public facebook::jni::HybridClass { - private: - friend HybridBase; - std::unique_ptr module_; +JniCache g_jni_cache; - public: - constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/Module;"; +} // anonymous namespace - static facebook::jni::local_ref initHybrid( - facebook::jni::alias_ref, - facebook::jni::alias_ref modelPath, - jint loadMode, - jint numThreads) { - return makeCxxInstance(modelPath, loadMode, numThreads); - } +namespace executorch::extension { + +// Native module handle class - named ExecuTorchJni to match friend declaration in Module +class ExecuTorchJni { + public: + std::unique_ptr module_; ExecuTorchJni( - facebook::jni::alias_ref modelPath, + JNIEnv* env, + jstring modelPath, jint loadMode, jint numThreads) { Module::LoadMode load_mode = Module::LoadMode::Mmap; @@ -273,17 +172,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #else auto etdump_gen = nullptr; #endif - module_ = std::make_unique( - modelPath->toStdString(), load_mode, std::move(etdump_gen)); + std::string path = jstring_to_string(env, modelPath); + module_ = std::make_unique(path, load_mode, std::move(etdump_gen)); #ifdef ET_USE_THREADPOOL - // Default to using cores/2 threadpool threads. The long-term plan is to - // improve performant core detection in CPUInfo, but for now we can use - // cores/2 as a sane default. - // - // Based on testing, this is almost universally faster than using all - // cores, as efficiency cores can be quite slow. In extreme cases, using - // all cores can be 10x slower than using cores/2. auto threadpool = executorch::extension::threadpool::get_threadpool(); if (threadpool) { int thread_count = @@ -295,244 +187,534 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #endif } - facebook::jni::local_ref> execute( - facebook::jni::alias_ref methodName, - facebook::jni::alias_ref< - facebook::jni::JArrayClass::javaobject> - jinputs) { - return execute_method(methodName->toStdString(), jinputs); + // Access protected methods_ member (friend class privilege) + Method* get_method(const std::string& method_name) { + auto it = module_->methods_.find(method_name); + if (it != module_->methods_.end()) { + return it->second.method.get(); + } + return nullptr; } +}; + +} // namespace executorch::extension + +namespace { + +// Helper to create Java Tensor from native tensor +jobject newJTensorFromTensor(JNIEnv* env, const executorch::aten::Tensor& tensor) { + g_jni_cache.init(env); - jint load_method(facebook::jni::alias_ref methodName) { - return static_cast(module_->load_method(methodName->toStdString())); + const auto scalarType = tensor.scalar_type(); + if (scalar_type_to_java_dtype.count(scalarType) == 0) { + std::stringstream ss; + ss << "executorch::aten::Tensor scalar type is not supported on java side"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; } + int jdtype = scalar_type_to_java_dtype.at(scalarType); - facebook::jni::local_ref> execute_method( - std::string method, - facebook::jni::alias_ref< - facebook::jni::JArrayClass::javaobject> - jinputs) { - // If no inputs is given, it will run with sample inputs (ones) - if (jinputs->size() == 0) { - auto result = module_->load_method(method); - if (result != Error::Ok) { - // Format hex string - std::stringstream ss; - ss << "Cannot get method names [Native Error: 0x" << std::hex - << std::uppercase << static_cast(result) << "]"; + // Create shape array + const auto& tensor_shape = tensor.sizes(); + jlongArray jTensorShape = env->NewLongArray(tensor_shape.size()); + if (jTensorShape == nullptr) { + return nullptr; + } + std::vector shape_vec; + for (const auto& s : tensor_shape) { + shape_vec.push_back(s); + } + env->SetLongArrayRegion(jTensorShape, 0, shape_vec.size(), shape_vec.data()); + + // Create ByteBuffer wrapping tensor data + jobject jTensorBuffer = env->NewDirectByteBuffer( + const_cast(tensor.const_data_ptr()), tensor.nbytes()); + if (jTensorBuffer == nullptr) { + env->DeleteLocalRef(jTensorShape); + return nullptr; + } - jni_helper::throwExecutorchException( - static_cast(result), ss.str()); - return {}; - } - auto&& underlying_method = module_->methods_[method].method; - auto&& buf = prepare_input_tensors(*underlying_method); - result = underlying_method->execute(); - if (result != Error::Ok) { - jni_helper::throwExecutorchException( - static_cast(result), - "Execution failed for method: " + method); - return {}; - } - facebook::jni::local_ref> jresult = - facebook::jni::JArrayClass::newArray( - underlying_method->outputs_size()); - - for (int i = 0; i < underlying_method->outputs_size(); i++) { - auto jevalue = - JEValue::newJEValueFromEValue(underlying_method->get_output(i)); - jresult->setElement(i, *jevalue); - } - return jresult; - } + // Set byte order to native order + jclass byteBufferClass = env->FindClass("java/nio/ByteBuffer"); + jmethodID orderMethod = + env->GetMethodID(byteBufferClass, "order", "(Ljava/nio/ByteOrder;)Ljava/nio/ByteBuffer;"); + jclass byteOrderClass = env->FindClass("java/nio/ByteOrder"); + jmethodID nativeOrderMethod = + env->GetStaticMethodID(byteOrderClass, "nativeOrder", "()Ljava/nio/ByteOrder;"); + jobject nativeOrder = env->CallStaticObjectMethod(byteOrderClass, nativeOrderMethod); + env->CallObjectMethod(jTensorBuffer, orderMethod, nativeOrder); + + env->DeleteLocalRef(byteBufferClass); + env->DeleteLocalRef(byteOrderClass); + env->DeleteLocalRef(nativeOrder); + + // Call nativeNewTensor static method (pass 0 for nativeHandle since we don't need it) + jobject result = env->CallStaticObjectMethod( + g_jni_cache.tensor_class, + g_jni_cache.tensor_nativeNewTensor, + jTensorBuffer, + jTensorShape, + jdtype, + static_cast(0)); + + env->DeleteLocalRef(jTensorBuffer); + env->DeleteLocalRef(jTensorShape); + + return result; +} - std::vector evalues; - std::vector tensors; - - static const auto typeCodeField = - JEValue::javaClassStatic()->getField("mTypeCode"); - - for (int i = 0; i < jinputs->size(); i++) { - auto jevalue = jinputs->getElement(i); - const auto typeCode = jevalue->getFieldValue(typeCodeField); - if (typeCode == JEValue::kTypeCodeTensor) { - tensors.emplace_back(JEValue::JEValueToTensorImpl(jevalue)); - evalues.emplace_back(tensors.back()); - } else if (typeCode == JEValue::kTypeCodeInt) { - int64_t value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } else if (typeCode == JEValue::kTypeCodeDouble) { - double value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } else if (typeCode == JEValue::kTypeCodeBool) { - bool value = jevalue->getFieldValue(typeCodeField); - evalues.emplace_back(value); - } - } +// Helper to create native TensorPtr from Java Tensor +TensorPtr newTensorFromJTensor(JNIEnv* env, jobject jtensor) { + g_jni_cache.init(env); -#ifdef EXECUTORCH_ANDROID_PROFILING - auto start = std::chrono::high_resolution_clock::now(); - auto result = module_->execute(method, evalues); - auto end = std::chrono::high_resolution_clock::now(); - auto duration = - std::chrono::duration_cast(end - start) - .count(); - ET_LOG(Debug, "Execution time: %lld ms.", duration); + jint jdtype = env->CallIntMethod(jtensor, g_jni_cache.tensor_dtypeJniCode); -#else - auto result = module_->execute(method, evalues); + jlongArray jshape = + static_cast(env->GetObjectField(jtensor, g_jni_cache.tensor_shape)); -#endif + jobject jbuffer = env->CallObjectMethod(jtensor, g_jni_cache.tensor_getRawDataBuffer); - if (!result.ok()) { - jni_helper::throwExecutorchException( - static_cast(result.error()), - "Execution failed for method: " + method); - return {}; - } + jsize rank = env->GetArrayLength(jshape); + + std::vector shapeArr(rank); + env->GetLongArrayRegion(jshape, 0, rank, shapeArr.data()); + + std::vector shape_vec; + shape_vec.reserve(rank); + + int64_t numel = 1; + for (int i = 0; i < rank; i++) { + shape_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; + } - facebook::jni::local_ref> jresult = - facebook::jni::JArrayClass::newArray(result.get().size()); + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + std::stringstream ss; + ss << "Unknown Tensor jdtype: [" << jdtype << "]"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } + + ScalarType scalar_type = java_dtype_to_scalar_type.at(jdtype); + const jlong dataCapacity = env->GetDirectBufferCapacity(jbuffer); + if (dataCapacity < 0) { + std::stringstream ss; + ss << "Tensor buffer is not direct or has invalid capacity"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } + + const size_t elementSize = executorch::runtime::elementSize(scalar_type); + const jlong expectedElements = static_cast(numel); + const jlong expectedBytes = expectedElements * static_cast(elementSize); + const bool matchesElements = dataCapacity == expectedElements; + const bool matchesBytes = dataCapacity == expectedBytes; + + if (!matchesElements && !matchesBytes) { + std::stringstream ss; + ss << "Tensor dimensions(elements number: " << numel + << ") inconsistent with buffer capacity " << dataCapacity + << " (element size bytes: " << elementSize << ")"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + return nullptr; + } + + void* data = env->GetDirectBufferAddress(jbuffer); + TensorPtr result = from_blob(data, shape_vec, scalar_type); - for (int i = 0; i < result.get().size(); i++) { - auto jevalue = JEValue::newJEValueFromEValue(result.get()[i]); - jresult->setElement(i, *jevalue); + env->DeleteLocalRef(jshape); + env->DeleteLocalRef(jbuffer); + + return result; +} + +// Helper to create Java EValue from native EValue +jobject newJEValueFromEValue(JNIEnv* env, EValue evalue) { + g_jni_cache.init(env); + + if (evalue.isTensor()) { + jobject jtensor = newJTensorFromTensor(env, evalue.toTensor()); + if (jtensor == nullptr) { + return nullptr; } - return jresult; + jobject result = env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_tensor, jtensor); + env->DeleteLocalRef(jtensor); + return result; + } else if (evalue.isInt()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_long, evalue.toInt()); + } else if (evalue.isDouble()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_double, evalue.toDouble()); + } else if (evalue.isBool()) { + return env->CallStaticObjectMethod( + g_jni_cache.evalue_class, + g_jni_cache.evalue_from_bool, + static_cast(evalue.toBool())); + } else if (evalue.isString()) { + std::string str = + std::string(evalue.toString().begin(), evalue.toString().end()); + jstring jstr = env->NewStringUTF(str.c_str()); + jobject result = env->CallStaticObjectMethod( + g_jni_cache.evalue_class, g_jni_cache.evalue_from_string, jstr); + env->DeleteLocalRef(jstr); + return result; } - facebook::jni::local_ref> - readLogBuffer() { - return readLogBufferUtil(); + std::stringstream ss; + ss << "Unknown EValue type: [" << static_cast(evalue.tag) << "]"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; +} + +// Helper to get TensorPtr from Java EValue +TensorPtr JEValueToTensorImpl(JNIEnv* env, jobject jevalue) { + g_jni_cache.init(env); + + jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode); + if (typeCode == kTypeCodeTensor) { + jobject jtensor = + env->CallObjectMethod(jevalue, g_jni_cache.evalue_toTensor); + TensorPtr result = newTensorFromJTensor(env, jtensor); + env->DeleteLocalRef(jtensor); + return result; + } + + std::stringstream ss; + ss << "Unknown EValue typeCode: " << typeCode; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str().c_str()); + return nullptr; +} + +} // namespace + +extern "C" { + +JNIEXPORT jlong JNICALL +Java_org_pytorch_executorch_Module_nativeCreate( + JNIEnv* env, + jclass /* clazz */, + jstring modelPath, + jint loadMode, + jint numThreads) { + auto* native = new executorch::extension::ExecuTorchJni(env, modelPath, loadMode, numThreads); + return reinterpret_cast(native); +} + +JNIEXPORT void JNICALL +Java_org_pytorch_executorch_Module_nativeDestroy( + JNIEnv* /* env */, + jclass /* clazz */, + jlong nativeHandle) { + if (nativeHandle != 0) { + auto* native = reinterpret_cast(nativeHandle); + delete native; } +} - static facebook::jni::local_ref> - readLogBufferStatic(facebook::jni::alias_ref) { - return readLogBufferUtil(); +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeExecute( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName, + jobjectArray jinputs) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; } - static facebook::jni::local_ref> - readLogBufferUtil() { -#ifdef __ANDROID__ + g_jni_cache.init(env); + + std::string method = jstring_to_string(env, methodName); + jsize inputSize = jinputs != nullptr ? env->GetArrayLength(jinputs) : 0; - facebook::jni::local_ref> ret; - - access_log_buffer([&](std::vector& buffer) { - const auto size = buffer.size(); - ret = facebook::jni::JArrayClass::newArray(size); - for (auto i = 0u; i < size; i++) { - const auto& entry = buffer[i]; - // Format the log entry as "[TIMESTAMP FUNCTION FILE:LINE] LEVEL - // MESSAGE". - std::stringstream ss; - ss << "[" << entry.timestamp << " " << entry.function << " " - << entry.filename << ":" << entry.line << "] " - << static_cast(entry.level) << " " << entry.message; - - facebook::jni::local_ref jstr_message = - facebook::jni::make_jstring(ss.str().c_str()); - (*ret)[i] = jstr_message; + // If no inputs is given, it will run with sample inputs (ones) + if (inputSize == 0) { + auto result = native->module_->load_method(method); + if (result != Error::Ok) { + std::stringstream ss; + ss << "Cannot get method names [Native Error: 0x" << std::hex + << std::uppercase << static_cast(result) << "]"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(result), ss.str()); + return nullptr; + } + auto* underlying_method = native->get_method(method); + if (underlying_method == nullptr) { + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), "Method not found: " + method); + return nullptr; + } + auto&& buf = prepare_input_tensors(*underlying_method); + result = underlying_method->execute(); + if (result != Error::Ok) { + executorch::jni_helper::throwExecutorchException( + env, static_cast(result), "Execution failed for method: " + method); + return nullptr; + } + + jobjectArray jresult = + env->NewObjectArray(underlying_method->outputs_size(), g_jni_cache.evalue_class, nullptr); + + for (int i = 0; i < underlying_method->outputs_size(); i++) { + jobject jevalue = newJEValueFromEValue(env, underlying_method->get_output(i)); + env->SetObjectArrayElement(jresult, i, jevalue); + if (jevalue != nullptr) { + env->DeleteLocalRef(jevalue); } - }); + } + return jresult; + } - return ret; -#else - return facebook::jni::JArrayClass::newArray(0); -#endif + std::vector evalues; + std::vector tensors; + + for (int i = 0; i < inputSize; i++) { + jobject jevalue = env->GetObjectArrayElement(jinputs, i); + jint typeCode = env->GetIntField(jevalue, g_jni_cache.evalue_mTypeCode); + + if (typeCode == kTypeCodeTensor) { + tensors.emplace_back(JEValueToTensorImpl(env, jevalue)); + evalues.emplace_back(tensors.back()); + } else if (typeCode == kTypeCodeInt) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass longClass = env->FindClass("java/lang/Long"); + jmethodID longValue = env->GetMethodID(longClass, "longValue", "()J"); + jlong value = env->CallLongMethod(mData, longValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(longClass); + } else if (typeCode == kTypeCodeDouble) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass doubleClass = env->FindClass("java/lang/Double"); + jmethodID doubleValue = env->GetMethodID(doubleClass, "doubleValue", "()D"); + jdouble value = env->CallDoubleMethod(mData, doubleValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(doubleClass); + } else if (typeCode == kTypeCodeBool) { + jobject mData = env->GetObjectField(jevalue, g_jni_cache.evalue_mData); + jclass boolClass = env->FindClass("java/lang/Boolean"); + jmethodID boolValue = env->GetMethodID(boolClass, "booleanValue", "()Z"); + jboolean value = env->CallBooleanMethod(mData, boolValue); + evalues.emplace_back(static_cast(value)); + env->DeleteLocalRef(mData); + env->DeleteLocalRef(boolClass); + } + env->DeleteLocalRef(jevalue); } - jboolean etdump() { #ifdef EXECUTORCH_ANDROID_PROFILING - executorch::etdump::ETDumpGen* etdumpgen = - (executorch::etdump::ETDumpGen*)module_->event_tracer(); - auto etdump_data = etdumpgen->get_etdump_data(); - - if (etdump_data.buf != nullptr && etdump_data.size > 0) { - int etdump_file = - open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644); - if (etdump_file == -1) { - ET_LOG(Error, "Cannot create result.etdump error: %d", errno); - return false; - } - ssize_t bytes_written = - write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size); - if (bytes_written == -1) { - ET_LOG(Error, "Cannot write result.etdump error: %d", errno); - return false; - } else { - ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written); - } - close(etdump_file); - free(etdump_data.buf); - return true; - } else { - ET_LOG(Error, "No ETDump data available!"); - } + auto start = std::chrono::high_resolution_clock::now(); + auto result = native->module_->execute(method, evalues); + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start).count(); + ET_LOG(Debug, "Execution time: %lld ms.", duration); +#else + auto result = native->module_->execute(method, evalues); #endif - return false; + + if (!result.ok()) { + executorch::jni_helper::throwExecutorchException( + env, + static_cast(result.error()), + "Execution failed for method: " + method); + return nullptr; } - facebook::jni::local_ref> getMethods() { - const auto& names_result = module_->method_names(); - if (!names_result.ok()) { - // Format hex string - std::stringstream ss; - ss << "Cannot get load module [Native Error: 0x" << std::hex - << std::uppercase << static_cast(names_result.error()) - << "]"; + jobjectArray jresult = + env->NewObjectArray(result.get().size(), g_jni_cache.evalue_class, nullptr); - jni_helper::throwExecutorchException( - static_cast(Error::InvalidArgument), ss.str()); - return {}; - } - const auto& methods = names_result.get(); - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(methods.size()); - int i = 0; - for (auto s : methods) { - facebook::jni::local_ref method_name = - facebook::jni::make_jstring(s.c_str()); - (*ret)[i] = method_name; - i++; + for (size_t i = 0; i < result.get().size(); i++) { + jobject jevalue = newJEValueFromEValue(env, result.get()[i]); + env->SetObjectArrayElement(jresult, i, jevalue); + if (jevalue != nullptr) { + env->DeleteLocalRef(jevalue); } - return ret; } + return jresult; +} + +JNIEXPORT jint JNICALL +Java_org_pytorch_executorch_Module_nativeLoadMethod( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return -1; + } + std::string method = jstring_to_string(env, methodName); + return static_cast(native->module_->load_method(method)); +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeGetMethods( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; + } + + const auto& names_result = native->module_->method_names(); + if (!names_result.ok()) { + std::stringstream ss; + ss << "Cannot get load module [Native Error: 0x" << std::hex + << std::uppercase << static_cast(names_result.error()) << "]"; + executorch::jni_helper::throwExecutorchException( + env, static_cast(Error::InvalidArgument), ss.str()); + return nullptr; + } + + const auto& methods = names_result.get(); + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(methods.size(), stringClass, nullptr); + + int i = 0; + for (auto s : methods) { + jstring method_name = env->NewStringUTF(s.c_str()); + env->SetObjectArrayElement(ret, i, method_name); + env->DeleteLocalRef(method_name); + i++; + } + env->DeleteLocalRef(stringClass); + return ret; +} - facebook::jni::local_ref> getUsedBackends( - facebook::jni::alias_ref methodName) { - auto methodMeta = module_->method_meta(methodName->toStdString()).get(); - std::unordered_set backends; - for (auto i = 0; i < methodMeta.num_backends(); i++) { - backends.insert(methodMeta.get_backend_name(i).get()); +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeGetUsedBackends( + JNIEnv* env, + jclass /* clazz */, + jlong nativeHandle, + jstring methodName) { + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return nullptr; + } + + std::string method = jstring_to_string(env, methodName); + auto methodMeta = native->module_->method_meta(method).get(); + std::unordered_set backends; + for (auto i = 0; i < methodMeta.num_backends(); i++) { + backends.insert(methodMeta.get_backend_name(i).get()); + } + + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(backends.size(), stringClass, nullptr); + + int i = 0; + for (auto s : backends) { + jstring backend_name = env->NewStringUTF(s.c_str()); + env->SetObjectArrayElement(ret, i, backend_name); + env->DeleteLocalRef(backend_name); + i++; + } + env->DeleteLocalRef(stringClass); + return ret; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeReadLogBuffer( + JNIEnv* env, + jclass /* clazz */, + jlong /* nativeHandle */) { +#ifdef __ANDROID__ + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = nullptr; + + access_log_buffer([&](std::vector& buffer) { + const auto size = buffer.size(); + ret = env->NewObjectArray(size, stringClass, nullptr); + for (auto i = 0u; i < size; i++) { + const auto& entry = buffer[i]; + std::stringstream ss; + ss << "[" << entry.timestamp << " " << entry.function << " " + << entry.filename << ":" << entry.line << "] " + << static_cast(entry.level) << " " << entry.message; + jstring jstr_message = env->NewStringUTF(ss.str().c_str()); + env->SetObjectArrayElement(ret, i, jstr_message); + env->DeleteLocalRef(jstr_message); } + }); + + env->DeleteLocalRef(stringClass); + return ret; +#else + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray ret = env->NewObjectArray(0, stringClass, nullptr); + env->DeleteLocalRef(stringClass); + return ret; +#endif +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic( + JNIEnv* env, + jclass clazz) { + return Java_org_pytorch_executorch_Module_nativeReadLogBuffer(env, clazz, 0); +} + +JNIEXPORT jboolean JNICALL +Java_org_pytorch_executorch_Module_nativeEtdump( + JNIEnv* /* env */, + jclass /* clazz */, + jlong nativeHandle) { +#ifdef EXECUTORCH_ANDROID_PROFILING + auto* native = reinterpret_cast(nativeHandle); + if (native == nullptr) { + return JNI_FALSE; + } - facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray(backends.size()); - int i = 0; - for (auto s : backends) { - facebook::jni::local_ref backend_name = - facebook::jni::make_jstring(s.c_str()); - (*ret)[i] = backend_name; - i++; + executorch::etdump::ETDumpGen* etdumpgen = + (executorch::etdump::ETDumpGen*)native->module_->event_tracer(); + auto etdump_data = etdumpgen->get_etdump_data(); + + if (etdump_data.buf != nullptr && etdump_data.size > 0) { + int etdump_file = + open("/data/local/tmp/result.etdump", O_WRONLY | O_CREAT, 0644); + if (etdump_file == -1) { + ET_LOG(Error, "Cannot create result.etdump error: %d", errno); + return JNI_FALSE; + } + ssize_t bytes_written = + write(etdump_file, (uint8_t*)etdump_data.buf, etdump_data.size); + if (bytes_written == -1) { + ET_LOG(Error, "Cannot write result.etdump error: %d", errno); + return JNI_FALSE; + } else { + ET_LOG(Info, "ETDump written %d bytes to file.", bytes_written); } - return ret; - } - - static void registerNatives() { - registerHybrid({ - makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), - makeNativeMethod("executeNative", ExecuTorchJni::execute), - makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method), - makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), - makeNativeMethod( - "readLogBufferStaticNative", ExecuTorchJni::readLogBufferStatic), - makeNativeMethod("etdump", ExecuTorchJni::etdump), - makeNativeMethod("getMethods", ExecuTorchJni::getMethods), - makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), - }); + close(etdump_file); + free(etdump_data.buf); + return JNI_TRUE; + } else { + ET_LOG(Error, "No ETDump data available!"); } -}; -} // namespace executorch::extension +#endif + return JNI_FALSE; +} + +} // extern "C" #ifdef EXECUTORCH_BUILD_LLAMA_JNI extern void register_natives_for_llm(); @@ -540,20 +722,72 @@ extern void register_natives_for_llm(); // No op if we don't build LLM void register_natives_for_llm() {} #endif -extern void register_natives_for_runtime(); #ifdef EXECUTORCH_BUILD_EXTENSION_TRAINING -extern void register_natives_for_training(); +extern void register_natives_for_training(JNIEnv* env); #else // No op if we don't build training JNI -void register_natives_for_training() {} +void register_natives_for_training(JNIEnv* /* env */) {} #endif +void register_natives_for_runtime(JNIEnv* env); + +void register_natives_for_module(JNIEnv* env) { + jclass module_class = env->FindClass("org/pytorch/executorch/Module"); + if (module_class == nullptr) { + ET_LOG(Error, "Failed to find Module class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod methods[] = { + {"nativeCreate", "(Ljava/lang/String;II)J", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeCreate)}, + {"nativeDestroy", "(J)V", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeDestroy)}, + {"nativeExecute", + "(JLjava/lang/String;[Lorg/pytorch/executorch/EValue;)[Lorg/pytorch/executorch/EValue;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeExecute)}, + {"nativeLoadMethod", "(JLjava/lang/String;)I", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeLoadMethod)}, + {"nativeGetMethods", "(J)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetMethods)}, + {"nativeGetUsedBackends", "(JLjava/lang/String;)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeGetUsedBackends)}, + {"nativeReadLogBuffer", "(J)[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBuffer)}, + {"nativeReadLogBufferStatic", "()[Ljava/lang/String;", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeReadLogBufferStatic)}, + {"nativeEtdump", "(J)Z", + reinterpret_cast(Java_org_pytorch_executorch_Module_nativeEtdump)}, + }; + // clang-format on + + int num_methods = sizeof(methods) / sizeof(methods[0]); + int result = env->RegisterNatives(module_class, methods, num_methods); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for Module"); + } + + env->DeleteLocalRef(module_class); +} + JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize(vm, [] { - executorch::extension::ExecuTorchJni::registerNatives(); - register_natives_for_llm(); - register_natives_for_runtime(); - register_natives_for_training(); - }); + g_jvm = vm; + JNIEnv* env = nullptr; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + // Initialize the JNI cache + g_jni_cache.init(env); + + // Register native methods + register_natives_for_module(env); + register_natives_for_llm(); + register_natives_for_runtime(env); + register_natives_for_training(env); + + return JNI_VERSION_1_6; } diff --git a/extension/android/jni/jni_layer_runtime.cpp b/extension/android/jni/jni_layer_runtime.cpp index 890e1d0fad9..32e7866353a 100644 --- a/extension/android/jni/jni_layer_runtime.cpp +++ b/extension/android/jni/jni_layer_runtime.cpp @@ -6,67 +6,90 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include #include +#include namespace executorch_jni { namespace runtime = ::executorch::ET_RUNTIME_NAMESPACE; -class AndroidRuntimeJni : public facebook::jni::JavaClass { - public: - constexpr static const char* kJavaDescriptor = - "Lorg/pytorch/executorch/ExecuTorchRuntime;"; - - static void registerNatives() { - javaClassStatic()->registerNatives({ - makeNativeMethod( - "getRegisteredOps", AndroidRuntimeJni::getRegisteredOps), - makeNativeMethod( - "getRegisteredBackends", AndroidRuntimeJni::getRegisteredBackends), - }); - } +} // namespace executorch_jni - // Returns a string array of all registered ops - static facebook::jni::local_ref> - getRegisteredOps(facebook::jni::alias_ref) { - auto kernels = runtime::get_registered_kernels(); - auto result = facebook::jni::JArrayClass::newArray(kernels.size()); +extern "C" { - for (size_t i = 0; i < kernels.size(); ++i) { - auto op = facebook::jni::make_jstring(kernels[i].name_); - result->setElement(i, op.get()); - } +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps( + JNIEnv* env, + jclass /* clazz */) { + auto kernels = executorch_jni::runtime::get_registered_kernels(); + + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray result = env->NewObjectArray(kernels.size(), stringClass, nullptr); - return result; + for (size_t i = 0; i < kernels.size(); ++i) { + jstring op = env->NewStringUTF(kernels[i].name_); + env->SetObjectArrayElement(result, i, op); + env->DeleteLocalRef(op); } - // Returns a string array of all registered backends - static facebook::jni::local_ref> - getRegisteredBackends(facebook::jni::alias_ref) { - int num_backends = runtime::get_num_registered_backends(); - auto result = facebook::jni::JArrayClass::newArray(num_backends); + env->DeleteLocalRef(stringClass); + return result; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends( + JNIEnv* env, + jclass /* clazz */) { + int num_backends = executorch_jni::runtime::get_num_registered_backends(); - for (int i = 0; i < num_backends; ++i) { - auto name_result = runtime::get_backend_name(i); - const char* name = ""; + jclass stringClass = env->FindClass("java/lang/String"); + jobjectArray result = env->NewObjectArray(num_backends, stringClass, nullptr); - if (name_result.ok()) { - name = *name_result; - } + for (int i = 0; i < num_backends; ++i) { + auto name_result = executorch_jni::runtime::get_backend_name(i); + const char* name = ""; - auto backend_str = facebook::jni::make_jstring(name); - result->setElement(i, backend_str.get()); + if (name_result.ok()) { + name = *name_result; } - return result; + jstring backend_str = env->NewStringUTF(name); + env->SetObjectArrayElement(result, i, backend_str); + env->DeleteLocalRef(backend_str); } -}; -} // namespace executorch_jni + env->DeleteLocalRef(stringClass); + return result; +} + +} // extern "C" + +void register_natives_for_runtime(JNIEnv* env) { + jclass runtime_class = env->FindClass("org/pytorch/executorch/ExecuTorchRuntime"); + if (runtime_class == nullptr) { + ET_LOG(Error, "Failed to find ExecuTorchRuntime class"); + env->ExceptionClear(); + return; + } + + // clang-format off + static const JNINativeMethod methods[] = { + {"nativeGetRegisteredOps", "()[Ljava/lang/String;", + reinterpret_cast( + Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredOps)}, + {"nativeGetRegisteredBackends", "()[Ljava/lang/String;", + reinterpret_cast( + Java_org_pytorch_executorch_ExecuTorchRuntime_nativeGetRegisteredBackends)}, + }; + // clang-format on + + int num_methods = sizeof(methods) / sizeof(methods[0]); + int result = env->RegisterNatives(runtime_class, methods, num_methods); + if (result != JNI_OK) { + ET_LOG(Error, "Failed to register native methods for ExecuTorchRuntime"); + } -void register_natives_for_runtime() { - executorch_jni::AndroidRuntimeJni::registerNatives(); + env->DeleteLocalRef(runtime_class); } diff --git a/extension/android/jni/jni_layer_training.cpp b/extension/android/jni/jni_layer_training.cpp index 5a5e9f24d2f..0641013a993 100644 --- a/extension/android/jni/jni_layer_training.cpp +++ b/extension/android/jni/jni_layer_training.cpp @@ -12,10 +12,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -28,19 +30,97 @@ using namespace torch::executor; namespace executorch::extension { -// Forward declarations from jni_layer.cpp +// Full implementation of TensorHybrid for training module (fbjni-based) class TensorHybrid : public facebook::jni::HybridClass { public: constexpr static const char* kJavaDescriptor = "Lorg/pytorch/executorch/Tensor;"; static facebook::jni::local_ref - newJTensorFromTensor(const executorch::aten::Tensor& tensor); + newJTensorFromTensor(const executorch::aten::Tensor& tensor) { + const auto scalarType = tensor.scalar_type(); + if (scalar_type_to_java_dtype.count(scalarType) == 0) { + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "executorch::aten::Tensor scalar type %d is not supported on java side", + static_cast(scalarType)); + } + int jdtype = scalar_type_to_java_dtype.at(scalarType); + + const auto& tensor_shape = tensor.sizes(); + std::vector tensor_shape_vec; + for (const auto& s : tensor_shape) { + tensor_shape_vec.push_back(s); + } + facebook::jni::local_ref jTensorShape = + facebook::jni::make_long_array(tensor_shape_vec.size()); + jTensorShape->setRegion( + 0, tensor_shape_vec.size(), tensor_shape_vec.data()); + + facebook::jni::local_ref jTensorBuffer = + facebook::jni::JByteBuffer::wrapBytes( + (uint8_t*)tensor.const_data_ptr(), tensor.nbytes()); + jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder()); + + static auto cls = TensorHybrid::javaClassStatic(); + static const auto jMethodNewTensor = + cls->getStaticMethod( + facebook::jni::local_ref, + facebook::jni::local_ref, + jint, + facebook::jni::local_ref)>("nativeNewTensor"); + return jMethodNewTensor( + cls, std::move(jTensorBuffer), std::move(jTensorShape), jdtype, nullptr); + } static TensorPtr newTensorFromJTensor( - facebook::jni::alias_ref jtensor); + facebook::jni::alias_ref jtensor) { + static const auto dtypeMethod = + TensorHybrid::javaClassStatic()->getMethod("dtypeJniCode"); + jint jdtype = dtypeMethod(jtensor); + + static auto shapeField = + TensorHybrid::javaClassStatic()->getField("shape"); + auto jshape = jtensor->getFieldValue(shapeField); + + static const auto dataBufferMethod = + TensorHybrid::javaClassStatic() + ->getMethod()>( + "getRawDataBuffer"); + facebook::jni::local_ref jbuffer = + dataBufferMethod(jtensor); + + const auto rank = jshape->size(); + + std::vector shapeArr(rank); + jshape->getRegion(0, rank, shapeArr.data()); + + std::vector sizes_vec; + sizes_vec.reserve(rank); + + int64_t numel = 1; + for (int i = 0; i < rank; i++) { + sizes_vec.push_back(shapeArr[i]); + } + for (int i = rank - 1; i >= 0; --i) { + numel *= shapeArr[i]; + } + + JNIEnv* jni = facebook::jni::Environment::current(); + void* dataPtr = jni->GetDirectBufferAddress(jbuffer.get()); + if (java_dtype_to_scalar_type.count(jdtype) == 0) { + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown Tensor jdtype: %d", + jdtype); + } + + ScalarType scalarType = java_dtype_to_scalar_type.at(jdtype); + return from_blob(dataPtr, sizes_vec, scalarType); + } }; +// Full implementation of JEValue for training module (fbjni-based) class JEValue : public facebook::jni::JavaClass { public: constexpr static const char* kJavaDescriptor = @@ -53,10 +133,69 @@ class JEValue : public facebook::jni::JavaClass { constexpr static int kTypeCodeBool = 5; static facebook::jni::local_ref newJEValueFromEValue( - runtime::EValue evalue); + runtime::EValue evalue) { + if (evalue.isTensor()) { + static auto jMethodTensor = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + return jMethodTensor( + JEValue::javaClassStatic(), + TensorHybrid::newJTensorFromTensor(evalue.toTensor())); + } else if (evalue.isInt()) { + static auto jMethodInt = + JEValue::javaClassStatic() + ->getStaticMethod(jlong)>( + "from"); + return jMethodInt(JEValue::javaClassStatic(), evalue.toInt()); + } else if (evalue.isDouble()) { + static auto jMethodDouble = + JEValue::javaClassStatic() + ->getStaticMethod(jdouble)>( + "from"); + return jMethodDouble(JEValue::javaClassStatic(), evalue.toDouble()); + } else if (evalue.isBool()) { + static auto jMethodBool = + JEValue::javaClassStatic() + ->getStaticMethod(jboolean)>( + "from"); + return jMethodBool(JEValue::javaClassStatic(), evalue.toBool()); + } else if (evalue.isString()) { + static auto jMethodStr = + JEValue::javaClassStatic() + ->getStaticMethod( + facebook::jni::local_ref)>("from"); + std::string str = + std::string(evalue.toString().begin(), evalue.toString().end()); + return jMethodStr( + JEValue::javaClassStatic(), facebook::jni::make_jstring(str)); + } + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown EValue type: %d", + static_cast(evalue.tag)); + return nullptr; + } static TensorPtr JEValueToTensorImpl( - facebook::jni::alias_ref JEValue); + facebook::jni::alias_ref jevalue) { + static const auto typeCodeField = + JEValue::javaClassStatic()->getField("mTypeCode"); + const auto typeCode = jevalue->getFieldValue(typeCodeField); + if (typeCode == JEValue::kTypeCodeTensor) { + static const auto jMethodGetTensor = + JEValue::javaClassStatic() + ->getMethod()>( + "toTensor"); + auto tensor = jMethodGetTensor(jevalue); + return TensorHybrid::newTensorFromJTensor(tensor); + } + facebook::jni::throwNewJavaException( + "java/lang/IllegalArgumentException", + "Unknown EValue typeCode: %d", + typeCode); + return nullptr; + } }; class ExecuTorchTrainingJni @@ -345,7 +484,7 @@ class SGDHybrid : public facebook::jni::HybridClass { } // namespace executorch::extension // Function to register training module natives -void register_natives_for_training() { +void register_natives_for_training(JNIEnv* /* env */) { executorch::extension::ExecuTorchTrainingJni::registerNatives(); executorch::extension::SGDHybrid::registerNatives(); };