diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index 3a033851be9..444a5166d95 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -8,6 +8,7 @@ package org.pytorch.executorch; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; @@ -89,6 +90,18 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc assertEquals(bananaClass, argmax(scores)); } + @Test + public void testXnnpackBackendRequired() throws IOException, URISyntaxException { + File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte")); + InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte"); + FileUtils.copyInputStreamToFile(inputStream, pteFile); + inputStream.close(); + + Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); + String[] expectedBackends = new String[] {"XnnpackBackend"}; + assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); + } + @Test public void testMv2Fp32() throws IOException, URISyntaxException { testClassification("/mv2_xnnpack_fp32.pte"); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index d4f1e99a3c7..2fd488dd1f1 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -137,6 +137,16 @@ public int loadMethod(String methodName) { } } + /** + * Returns the names of the methods in a certain method. + * + * @param methodName method name to query + * @return an array of backend name + */ + public String[] getUsedBackends(String methodName) { + return mNativePeer.getUsedBackends(methodName); + } + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { return mNativePeer.readLogBuffer(); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java index a5487a4702e..5700176261b 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -55,6 +55,10 @@ public void resetNative() { @DoNotStrip public native int loadMethod(String methodName); + /** Return the list of backends used by a method */ + @DoNotStrip + public native String[] getUsedBackends(String methodName); + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ @DoNotStrip public native String[] readLogBuffer(); diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index f3c62e1d70f..a78f3801c64 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "jni_layer_constants.h" @@ -395,6 +396,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #endif } + facebook::jni::local_ref> getUsedBackends( + facebook::jni::alias_ref methodName) { + auto methodMeta = module_->method_meta(methodName->toStdString()).get(); + std::unordered_set backends; + for (auto i = 0; i < methodMeta.num_backends(); i++) { + backends.insert(methodMeta.get_backend_name(i).get()); + } + + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray(backends.size()); + int i = 0; + for (auto s : backends) { + facebook::jni::local_ref backend_name = + facebook::jni::make_jstring(s.c_str()); + (*ret)[i] = backend_name; + i++; + } + return ret; + } + static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), @@ -402,6 +423,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), + makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } };