Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

package org.pytorch.executorch;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;
Expand Down Expand Up @@ -89,6 +90,18 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc
assertEquals(bananaClass, argmax(scores));
}

@Test
public void testXnnpackBackendRequired() throws IOException, URISyntaxException {
File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte"));
InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte");
FileUtils.copyInputStreamToFile(inputStream, pteFile);
inputStream.close();

Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"));
String[] expectedBackends = new String[] {"XnnpackBackend"};
assertArrayEquals(expectedBackends, module.getUsedBackends("forward"));
}

@Test
public void testMv2Fp32() throws IOException, URISyntaxException {
testClassification("/mv2_xnnpack_fp32.pte");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,16 @@ public int loadMethod(String methodName) {
}
}

/**
* Returns the names of the methods in a certain method.
*
* @param methodName method name to query
* @return an array of backend name
*/
public String[] getUsedBackends(String methodName) {
return mNativePeer.getUsedBackends(methodName);
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
return mNativePeer.readLogBuffer();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ public void resetNative() {
@DoNotStrip
public native int loadMethod(String methodName);

/** Return the list of backends used by a method */
@DoNotStrip
public native String[] getUsedBackends(String methodName);

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
@DoNotStrip
public native String[] readLogBuffer();
Expand Down
22 changes: 22 additions & 0 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "jni_layer_constants.h"
Expand Down Expand Up @@ -395,13 +396,34 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
#endif
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
facebook::jni::alias_ref<jstring> methodName) {
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
std::unordered_set<std::string> backends;
for (auto i = 0; i < methodMeta.num_backends(); i++) {
backends.insert(methodMeta.get_backend_name(i).get());
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
facebook::jni::JArrayClass<jstring>::newArray(backends.size());
int i = 0;
for (auto s : backends) {
facebook::jni::local_ref<facebook::jni::JString> backend_name =
facebook::jni::make_jstring(s.c_str());
(*ret)[i] = backend_name;
i++;
}
return ret;
}

static void registerNatives() {
registerHybrid({
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
makeNativeMethod("forward", ExecuTorchJni::forward),
makeNativeMethod("execute", ExecuTorchJni::execute),
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
});
}
};
Expand Down
Loading