Skip to content

Commit 9b15804

Browse files
committed
Add API to get backends required by a method
1 parent 905ccd0 commit 9b15804

File tree

4 files changed

+37
-0
lines changed

4 files changed

+37
-0
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.pytorch.executorch;
1010

11+
import static org.junit.Assert.assertArrayEquals;
1112
import static org.junit.Assert.assertEquals;
1213
import static org.junit.Assert.assertTrue;
1314
import static org.junit.Assert.assertFalse;
@@ -89,6 +90,13 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc
8990
assertEquals(bananaClass, argmax(scores));
9091
}
9192

93+
@Test
94+
public void testXnnpackBackendRequired() {
95+
Module module = Module.load(getTestFilePath(filePath));
96+
String[] expectedBackends = new String[] {"xnnpack"};
97+
assertArrayEquals(expectedBackends, module.getUsedBackends("forward"));
98+
}
99+
92100
@Test
93101
public void testMv2Fp32() throws IOException, URISyntaxException {
94102
testClassification("/mv2_xnnpack_fp32.pte");

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ public int loadMethod(String methodName) {
137137
}
138138
}
139139

140+
/**
141+
* Returns the names of the methods in a certain method.
142+
*
143+
* @param methodName method name to query
144+
* @return an array of backend name
145+
*/
146+
public String[] getUsedBackends(String methodName) {
147+
return mNativePeer.getUsedBackends(methodName);
148+
}
149+
140150
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
141151
public String[] readLogBuffer() {
142152
return mNativePeer.readLogBuffer();

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public void resetNative() {
5555
@DoNotStrip
5656
public native int loadMethod(String methodName);
5757

58+
/** Return the list of backends used by a method */
59+
@DoNotStrip
60+
public native String[] getUsedBackends(String methodName);
61+
62+
5863
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
5964
@DoNotStrip
6065
public native String[] readLogBuffer();

extension/android/jni/jni_layer.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,13 +395,27 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
395395
#endif
396396
}
397397

398+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>>
399+
getUsedBackends(facebook::jni::alias_ref<jstring> methodName) {
400+
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
401+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret
402+
= facebook::jni::JArrayClass<jstring>::newArray(methodMeta.num_backends());
403+
for (auto i = 0; i < methodMeta.num_backends(); i++) {
404+
facebook::jni::local_ref<facebook::jni::JString> backend_name =
405+
facebook::jni::make_jstring(methodMeta.get_backend_name(i).get());
406+
(*ret)[i] = backend_name;
407+
}
408+
return ret;
409+
}
410+
398411
static void registerNatives() {
399412
registerHybrid({
400413
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
401414
makeNativeMethod("forward", ExecuTorchJni::forward),
402415
makeNativeMethod("execute", ExecuTorchJni::execute),
403416
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
404417
makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
418+
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
405419
});
406420
}
407421
};

0 commit comments

Comments
 (0)