From 42940bd0b93a5b713b31cfc4552ab947b1f6b756 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 16:31:23 -0700 Subject: [PATCH 01/12] Android Java throw exception if cannot load model or tokenizer --- .../java/org/pytorch/executorch/Module.java | 5 +++ .../executorch/extension/llm/LlmModule.java | 36 +++++++++++++------ 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index f3f543dc2a8..c2b39c5756a 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -11,6 +11,8 @@ import android.util.Log; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; @@ -52,6 +54,9 @@ public static Module load(final String modelPath, int loadMode) { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } + if (!Files.isReadable(Paths.get(modelPath))) { + throw new RuntimeException("Cannot load model path " + modelPath); + } return new Module(new NativePeer(modelPath, loadMode)); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index f845937be41..98dfe7ddf02 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -12,6 +12,8 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; +import java.nio.file.Files; +import java.nio.file.Paths; import org.pytorch.executorch.annotations.Experimental; /** @@ -41,9 +43,24 @@ public class LlmModule { private static native HybridData initHybrid( int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath); + /** + * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and + * data path. + */ + public LlmModule( + int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { + if (!Files.isReadable(Paths.get(modulePath))) { + throw new RuntimeException("Cannot load model path " + modulePath); + } + if (!Files.isReadable(Paths.get(tokenizerPath))) { + throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); + } + mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath); + } + /** Constructs a LLM Module for a model with given model path, tokenizer, temperature. */ public LlmModule(String modulePath, String tokenizerPath, float temperature) { - mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, null); } /** @@ -51,23 +68,22 @@ public LlmModule(String modulePath, String tokenizerPath, float temperature) { * path. */ public LlmModule(String modulePath, String tokenizerPath, float temperature, String dataPath) { - mHybridData = initHybrid(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath); + this(MODEL_TYPE_TEXT, modulePath, tokenizerPath, temperature, dataPath); } /** Constructs a LLM Module for a model with given path, tokenizer, and temperature. */ public LlmModule(int modelType, String modulePath, String tokenizerPath, float temperature) { - mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, null); + this(modelType, modulePath, tokenizerPath, temperature, null); } /** Constructs a LLM Module for a model with the given LlmModuleConfig */ public LlmModule(LlmModuleConfig config) { - mHybridData = - initHybrid( - config.getModelType(), - config.getModulePath(), - config.getTokenizerPath(), - config.getTemperature(), - config.getDataPath()); + this( + config.getModelType(), + config.getModulePath(), + config.getTokenizerPath(), + config.getTemperature(), + config.getDataPath()); } public void resetNative() { From 246b509f22b25cd45077732f7dc892d61b2e6fb4 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 16:41:17 -0700 Subject: [PATCH 02/12] Use old API --- .../src/main/java/org/pytorch/executorch/Module.java | 5 ++--- .../org/pytorch/executorch/extension/llm/LlmModule.java | 7 +++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index c2b39c5756a..4379ca0d396 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -11,8 +11,7 @@ import android.util.Log; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.nio.file.Files; -import java.nio.file.Paths; +import java.io.File; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import org.pytorch.executorch.annotations.Experimental; @@ -54,7 +53,7 @@ public static Module load(final String modelPath, int loadMode) { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } - if (!Files.isReadable(Paths.get(modelPath))) { + if (!new File(modelPath).canRead()) { throw new RuntimeException("Cannot load model path " + modelPath); } return new Module(new NativePeer(modelPath, loadMode)); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 98dfe7ddf02..384cc9b871b 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -12,8 +12,7 @@ import com.facebook.jni.annotations.DoNotStrip; import com.facebook.soloader.nativeloader.NativeLoader; import com.facebook.soloader.nativeloader.SystemDelegate; -import java.nio.file.Files; -import java.nio.file.Paths; +import java.io.File; import org.pytorch.executorch.annotations.Experimental; /** @@ -49,10 +48,10 @@ private static native HybridData initHybrid( */ public LlmModule( int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { - if (!Files.isReadable(Paths.get(modulePath))) { + if (!new File(modulePath).canRead()) { throw new RuntimeException("Cannot load model path " + modulePath); } - if (!Files.isReadable(Paths.get(tokenizerPath))) { + if (!new File(tokenizerPath).canRead()) { throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); } mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath); From 8caebca938b56cbf4aaaca2986022d9393707773 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 16:53:55 -0700 Subject: [PATCH 03/12] raise if it's not a regular file --- .../src/main/java/org/pytorch/executorch/Module.java | 3 ++- .../org/pytorch/executorch/extension/llm/LlmModule.java | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index 4379ca0d396..cb544dd6a37 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -53,7 +53,8 @@ public static Module load(final String modelPath, int loadMode) { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } - if (!new File(modelPath).canRead()) { + File modelFile = new File(modelPath); + if (!modelFile.canRead() || !modelFile.isFile()) { throw new RuntimeException("Cannot load model path " + modelPath); } return new Module(new NativePeer(modelPath, loadMode)); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index 384cc9b871b..69e302edf78 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -48,10 +48,12 @@ private static native HybridData initHybrid( */ public LlmModule( int modelType, String modulePath, String tokenizerPath, float temperature, String dataPath) { - if (!new File(modulePath).canRead()) { + File modelFile = new File(modulePath); + if (!modelFile.canRead() || !modelFile.isFile()) { throw new RuntimeException("Cannot load model path " + modulePath); } - if (!new File(tokenizerPath).canRead()) { + File tokenizerFile = new File(tokenizerPath); + if (!tokenizerFile.canRead() || !tokenizerFile.isFile()) { throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath); } mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataPath); From 905ccd0b60143cfdd0bc3b286fb332a4c5ea0d41 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 17:29:22 -0700 Subject: [PATCH 04/12] update test --- .../executorch/ModuleInstrumentationTest.java | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java index be6efdd67be..21b6a0610fd 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleInstrumentationTest.java @@ -96,20 +96,9 @@ public void testModuleLoadForwardExplicit() throws IOException{ assertTrue(results[0].isTensor()); } - @Test + @Test(expected = RuntimeException.class) public void testModuleLoadNonExistantFile() throws IOException{ Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); - - EValue[] results = module.forward(); - assertEquals(null, results); - } - - @Test - public void testModuleLoadMethodNonExistantFile() throws IOException{ - Module module = Module.load(getTestFilePath(MISSING_FILE_NAME)); - - int loadMethod = module.loadMethod(FORWARD_METHOD); - assertEquals(loadMethod, ACCESS_FAILED); } @Test @@ -146,11 +135,11 @@ public void testForwardOnDestroyedModule() throws IOException{ assertEquals(loadMethod, OK); module.destroy(); - + EValue[] results = module.forward(); assertEquals(0, results.length); } - + @Test public void testForwardFromMultipleThreads() throws InterruptedException, IOException { Module module = Module.load(getTestFilePath(TEST_FILE_NAME)); @@ -169,7 +158,7 @@ public void run() { assertTrue(results[0].isTensor()); completed.incrementAndGet(); } catch (InterruptedException e) { - + } } }; From 9b158042ba306256957c730ec04358c48825d3ec Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 18:10:59 -0700 Subject: [PATCH 05/12] Add API to get backends required by a method --- .../java/org/pytorch/executorch/ModuleE2ETest.java | 8 ++++++++ .../main/java/org/pytorch/executorch/Module.java | 10 ++++++++++ .../java/org/pytorch/executorch/NativePeer.java | 5 +++++ extension/android/jni/jni_layer.cpp | 14 ++++++++++++++ 4 files changed, 37 insertions(+) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index 3a033851be9..cd58c4aca81 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -8,6 +8,7 @@ package org.pytorch.executorch; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertFalse; @@ -89,6 +90,13 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc assertEquals(bananaClass, argmax(scores)); } + @Test + public void testXnnpackBackendRequired() { + Module module = Module.load(getTestFilePath(filePath)); + String[] expectedBackends = new String[] {"xnnpack"}; + assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); + } + @Test public void testMv2Fp32() throws IOException, URISyntaxException { testClassification("/mv2_xnnpack_fp32.pte"); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java index cb544dd6a37..8fbbf531f65 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java @@ -137,6 +137,16 @@ public int loadMethod(String methodName) { } } + /** + * Returns the names of the methods in a certain method. + * + * @param methodName method name to query + * @return an array of backend name + */ + public String[] getUsedBackends(String methodName) { + return mNativePeer.getUsedBackends(methodName); + } + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ public String[] readLogBuffer() { return mNativePeer.readLogBuffer(); diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java index a5487a4702e..122cde1f028 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -55,6 +55,11 @@ public void resetNative() { @DoNotStrip public native int loadMethod(String methodName); + /** Return the list of backends used by a method */ + @DoNotStrip + public native String[] getUsedBackends(String methodName); + + /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ @DoNotStrip public native String[] readLogBuffer(); diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index f3c62e1d70f..2ec60f5d9a4 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -395,6 +395,19 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #endif } + facebook::jni::local_ref> + getUsedBackends(facebook::jni::alias_ref methodName) { + auto methodMeta = module_->method_meta(methodName->toStdString()).get(); + facebook::jni::local_ref> ret + = facebook::jni::JArrayClass::newArray(methodMeta.num_backends()); + for (auto i = 0; i < methodMeta.num_backends(); i++) { + facebook::jni::local_ref backend_name = + facebook::jni::make_jstring(methodMeta.get_backend_name(i).get()); + (*ret)[i] = backend_name; + } + return ret; + } + static void registerNatives() { registerHybrid({ makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid), @@ -402,6 +415,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass { makeNativeMethod("execute", ExecuTorchJni::execute), makeNativeMethod("loadMethod", ExecuTorchJni::load_method), makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer), + makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends), }); } }; From 297960f48cf08516561af0ac108720e1c4156118 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 19:01:27 -0700 Subject: [PATCH 06/12] fix test --- .../androidTest/java/org/pytorch/executorch/ModuleE2ETest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index cd58c4aca81..cf77f59d3fe 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -92,7 +92,7 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc @Test public void testXnnpackBackendRequired() { - Module module = Module.load(getTestFilePath(filePath)); + Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); String[] expectedBackends = new String[] {"xnnpack"}; assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); } From b8bf076ca4b5b0636a7e385e8990c22050e2bf29 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 19:53:33 -0700 Subject: [PATCH 07/12] fix test --- .../java/org/pytorch/executorch/ModuleE2ETest.java | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index cf77f59d3fe..56692bd4134 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -92,6 +92,11 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc @Test public void testXnnpackBackendRequired() { + File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte")); + InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte"); + FileUtils.copyInputStreamToFile(inputStream, pteFile); + inputStream.close(); + Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); String[] expectedBackends = new String[] {"xnnpack"}; assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); From a833d98825a2c53a5b6e7992ab9ed0ef78f68a45 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Thu, 15 May 2025 22:17:35 -0700 Subject: [PATCH 08/12] fix test --- .../androidTest/java/org/pytorch/executorch/ModuleE2ETest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index 56692bd4134..f38f4363481 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -91,7 +91,7 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc } @Test - public void testXnnpackBackendRequired() { + public void testXnnpackBackendRequired() throws IOException, URISyntaxException { File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte")); InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte"); FileUtils.copyInputStreamToFile(inputStream, pteFile); From eaa2aff38af0e1a6a535a630eedf6c23a34fae06 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 16 May 2025 15:32:37 -0700 Subject: [PATCH 09/12] lint --- .../java/org/pytorch/executorch/ModuleE2ETest.java | 2 +- .../main/java/org/pytorch/executorch/NativePeer.java | 1 - extension/android/jni/jni_layer.cpp | 11 ++++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index f38f4363481..444a5166d95 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -98,7 +98,7 @@ public void testXnnpackBackendRequired() throws IOException, URISyntaxException inputStream.close(); Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); - String[] expectedBackends = new String[] {"xnnpack"}; + String[] expectedBackends = new String[] {"XnnpackBackend"}; assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); } diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java index 122cde1f028..5700176261b 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java @@ -59,7 +59,6 @@ public void resetNative() { @DoNotStrip public native String[] getUsedBackends(String methodName); - /** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */ @DoNotStrip public native String[] readLogBuffer(); diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 2ec60f5d9a4..5452f904ed7 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -395,14 +395,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass { #endif } - facebook::jni::local_ref> - getUsedBackends(facebook::jni::alias_ref methodName) { + facebook::jni::local_ref> getUsedBackends( + facebook::jni::alias_ref methodName) { auto methodMeta = module_->method_meta(methodName->toStdString()).get(); - facebook::jni::local_ref> ret - = facebook::jni::JArrayClass::newArray(methodMeta.num_backends()); + facebook::jni::local_ref> ret = + facebook::jni::JArrayClass::newArray( + methodMeta.num_backends()); for (auto i = 0; i < methodMeta.num_backends(); i++) { facebook::jni::local_ref backend_name = - facebook::jni::make_jstring(methodMeta.get_backend_name(i).get()); + facebook::jni::make_jstring(methodMeta.get_backend_name(i).get()); (*ret)[i] = backend_name; } return ret; From 09e0776aad18e24b08871f826e9e01e99aee1b90 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 16 May 2025 16:25:05 -0700 Subject: [PATCH 10/12] Fix test --- .../androidTest/java/org/pytorch/executorch/ModuleE2ETest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index 444a5166d95..8667a92477f 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -98,7 +98,7 @@ public void testXnnpackBackendRequired() throws IOException, URISyntaxException inputStream.close(); Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); - String[] expectedBackends = new String[] {"XnnpackBackend"}; + String[] expectedBackends = new String[] {"XnnpackBackend", "XnnpackBackend"}; assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); } From 105511388b5d4581473745950b205a22cf2f6bcb Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 16 May 2025 16:41:52 -0700 Subject: [PATCH 11/12] fix dupe --- .../java/org/pytorch/executorch/ModuleE2ETest.java | 2 +- extension/android/jni/jni_layer.cpp | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java index 8667a92477f..444a5166d95 100644 --- a/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java +++ b/extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java @@ -98,7 +98,7 @@ public void testXnnpackBackendRequired() throws IOException, URISyntaxException inputStream.close(); Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte")); - String[] expectedBackends = new String[] {"XnnpackBackend", "XnnpackBackend"}; + String[] expectedBackends = new String[] {"XnnpackBackend"}; assertArrayEquals(expectedBackends, module.getUsedBackends("forward")); } diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 5452f904ed7..543fde1e57a 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include "jni_layer_constants.h" @@ -398,13 +399,20 @@ class ExecuTorchJni : public facebook::jni::HybridClass { facebook::jni::local_ref> getUsedBackends( facebook::jni::alias_ref methodName) { auto methodMeta = module_->method_meta(methodName->toStdString()).get(); + std::unordered_set backends; + for (auto i = 0; i < methodMeta.num_backends(); i++) { + backends.insert(methodMeta.get_backend_name(i).get()); + } + facebook::jni::local_ref> ret = facebook::jni::JArrayClass::newArray( - methodMeta.num_backends()); - for (auto i = 0; i < methodMeta.num_backends(); i++) { + backends.size()); + int i = 0; + for (auto s: backends) { facebook::jni::local_ref backend_name = - facebook::jni::make_jstring(methodMeta.get_backend_name(i).get()); + facebook::jni::make_jstring(s.c_str()); (*ret)[i] = backend_name; + i++; } return ret; } From e91bbdb0ff1212e0708b858a24d5c2e4537eb943 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Fri, 16 May 2025 17:00:27 -0700 Subject: [PATCH 12/12] lint --- extension/android/jni/jni_layer.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index 543fde1e57a..a78f3801c64 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -405,10 +405,9 @@ class ExecuTorchJni : public facebook::jni::HybridClass { } facebook::jni::local_ref> ret = - facebook::jni::JArrayClass::newArray( - backends.size()); + facebook::jni::JArrayClass::newArray(backends.size()); int i = 0; - for (auto s: backends) { + for (auto s : backends) { facebook::jni::local_ref backend_name = facebook::jni::make_jstring(s.c_str()); (*ret)[i] = backend_name;