diff --git a/extension/android/BUCK b/extension/android/BUCK index 0fad11eb677..78ee57aae90 100644 --- a/extension/android/BUCK +++ b/extension/android/BUCK @@ -8,6 +8,7 @@ non_fbcode_target(_kind = fb_android_library, srcs = [ "executorch_android/src/main/java/org/pytorch/executorch/DType.java", "executorch_android/src/main/java/org/pytorch/executorch/EValue.java", + "executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java", "executorch_android/src/main/java/org/pytorch/executorch/Module.java", "executorch_android/src/main/java/org/pytorch/executorch/Tensor.java", "executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java", diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt index 7f633bac8b9..b496d9718c3 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.kt @@ -69,7 +69,7 @@ class ModuleE2ETest { val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")) val expectedBackends = arrayOf("XnnpackBackend") - Assert.assertArrayEquals(expectedBackends, module.getUsedBackends("forward")) + Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends()) } @Test diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt index b94e804ff4f..dd3f4b880a6 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.kt @@ -55,6 +55,15 @@ class ModuleInstrumentationTest { Assert.assertTrue(results[0].isTensor) } + @Test + @Throws(IOException::class, URISyntaxException::class) + fun testMethodMetadata() { + val module = Module.load(getTestFilePath(TEST_FILE_NAME)) + + Assert.assertArrayEquals(arrayOf("forward"), module.getMethods()) + Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty()) + } + @Test @Throws(IOException::class) fun testModuleLoadMethodAndForward() { @@ -91,7 +100,7 @@ class ModuleInstrumentationTest { Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong()) } - @Test + @Test(expected = RuntimeException::class) @Throws(IOException::class) fun testNonPteFile() { val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME)) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java new file mode 100644 index 00000000000..b2dde35a2d8 --- /dev/null +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +/** Helper class to access the metadata for a method from a Module */ +public class MethodMetadata { + private String mName; + + private String[] mBackends; + + MethodMetadata setName(String name) { + mName = name; + return this; + } + + /** + * @return Method name + */ + public String getName() { + return mName; + } + + MethodMetadata setBackends(String[] backends) { + mBackends = backends; + return this; + } + + /** + * @return Backends used for this method + */ + public String[] getBackends() { + return mBackends; + } +} diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index bbfd3254111..a68c180aa82 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -14,6 +14,8 @@ import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; import java.io.File; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; @@ -48,12 +50,27 @@ public class Module { private final HybridData mHybridData; + private final Map mMethodMetadata; + @DoNotStrip private static native HybridData initHybrid( String moduleAbsolutePath, int loadMode, int initHybrid); private Module(String moduleAbsolutePath, int loadMode, int numThreads) { mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads); + + mMethodMetadata = populateMethodMeta(); + } + + Map populateMethodMeta() { + String[] methods = getMethods(); + Map metadata = new HashMap(); + for (int i = 0; i < methods.length; i++) { + String name = methods[i]; + metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name))); + } + + return metadata; } /** Lock protecting the non-thread safe methods in mHybridData. */ @@ -158,13 +175,34 @@ public int loadMethod(String methodName) { private native int loadMethodNative(String methodName); /** - * Returns the names of the methods in a certain method. + * Returns the names of the backends in a certain method. * * @param methodName method name to query * @return an array of backend name */ @DoNotStrip - public native String[] getUsedBackends(String methodName); + private native String[] getUsedBackends(String methodName); + + /** + * Returns the names of methods. + * + * @return name of methods in this Module + */ + @DoNotStrip + public native String[] getMethods(); + + /** + * Get the corresponding @MethodMetadata for a method + * + * @param name method name + * @return @MethodMetadata for this method + */ + public MethodMetadata getMethodMetadata(String name) { + if (!mMethodMetadata.containsKey(name)) { + throw new RuntimeException("method " + name + "does not exist for this module"); + } + return mMethodMetadata.get(name); + } /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 936593abee8..bbe47e98a06 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -431,6 +431,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass { return false; } + facebook::jni::local_ref> getMethods() { + const auto& names_result = module_->method_names(); + if (!names_result.ok()) { + facebook::jni::throwNewJavaException( + facebook::jni::gJavaLangIllegalArgumentException, + "Cannot get load module"); + } + const auto& methods = names_result.get(); + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray(methods.size()); + int i = 0; + for (auto s : methods) { + facebook::jni::local_ref method_name = + facebook::jni::make_jstring(s.c_str()); + (*ret)[i] = method_name; + i++; + } + return ret; + } + facebook::jni::local_ref> getUsedBackends( facebook::jni::alias_ref methodName) { auto methodMeta = module_->method_meta(methodName->toStdString()).get(); @@ -458,6 +478,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method), makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer), makeNativeMethod("etdump", ExecuTorchJni::etdump), + makeNativeMethod("getMethods", ExecuTorchJni::getMethods), makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); }