Skip to content
1 change: 1 addition & 0 deletions extension/android/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ non_fbcode_target(_kind = fb_android_library,
srcs = [
"executorch_android/src/main/java/org/pytorch/executorch/DType.java",
"executorch_android/src/main/java/org/pytorch/executorch/EValue.java",
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",
"executorch_android/src/main/java/org/pytorch/executorch/annotations/Experimental.java",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class ModuleE2ETest {

val module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"))
val expectedBackends = arrayOf("XnnpackBackend")
Assert.assertArrayEquals(expectedBackends, module.getUsedBackends("forward"))
Assert.assertArrayEquals(expectedBackends, module.getMethodMetadata("forward").getBackends())
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ class ModuleInstrumentationTest {
Assert.assertTrue(results[0].isTensor)
}

@Test
@Throws(IOException::class, URISyntaxException::class)
fun testMethodMetadata() {
val module = Module.load(getTestFilePath(TEST_FILE_NAME))

Assert.assertArrayEquals(arrayOf("forward"), module.getMethods())
Assert.assertTrue(module.getMethodMetadata("forward").backends.isEmpty())
}

@Test
@Throws(IOException::class)
fun testModuleLoadMethodAndForward() {
Expand Down Expand Up @@ -91,7 +100,7 @@ class ModuleInstrumentationTest {
Assert.assertEquals(loadMethod.toLong(), INVALID_ARGUMENT.toLong())
}

@Test
@Test(expected = RuntimeException::class)
@Throws(IOException::class)
fun testNonPteFile() {
val module = Module.load(getTestFilePath(NON_PTE_FILE_NAME))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

package org.pytorch.executorch;

/** Helper class to access the metadata for a method from a Module */
public class MethodMetadata {
private String mName;

private String[] mBackends;

MethodMetadata setName(String name) {
mName = name;
return this;
}

/**
* @return Method name
*/
public String getName() {
return mName;
}

MethodMetadata setBackends(String[] backends) {
mBackends = backends;
return this;
}

/**
* @return Backends used for this method
*/
public String[] getBackends() {
return mBackends;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import com.facebook.soloader.nativeloader.NativeLoader;
import com.facebook.soloader.nativeloader.SystemDelegate;
import java.io.File;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.pytorch.executorch.annotations.Experimental;
Expand Down Expand Up @@ -48,12 +50,27 @@ public class Module {

private final HybridData mHybridData;

private final Map<String, MethodMetadata> mMethodMetadata;

@DoNotStrip
private static native HybridData initHybrid(
String moduleAbsolutePath, int loadMode, int initHybrid);

private Module(String moduleAbsolutePath, int loadMode, int numThreads) {
mHybridData = initHybrid(moduleAbsolutePath, loadMode, numThreads);

mMethodMetadata = populateMethodMeta();
}

Map<String, MethodMetadata> populateMethodMeta() {
String[] methods = getMethods();
Map<String, MethodMetadata> metadata = new HashMap<String, MethodMetadata>();
for (int i = 0; i < methods.length; i++) {
String name = methods[i];
metadata.put(name, new MethodMetadata().setName(name).setBackends(getUsedBackends(name)));
}

return metadata;
}

/** Lock protecting the non-thread safe methods in mHybridData. */
Expand Down Expand Up @@ -158,13 +175,34 @@ public int loadMethod(String methodName) {
private native int loadMethodNative(String methodName);

/**
* Returns the names of the methods in a certain method.
* Returns the names of the backends in a certain method.
*
* @param methodName method name to query
* @return an array of backend name
*/
@DoNotStrip
public native String[] getUsedBackends(String methodName);
private native String[] getUsedBackends(String methodName);

/**
* Returns the names of methods.
*
* @return name of methods in this Module
*/
@DoNotStrip
public native String[] getMethods();

/**
* Get the corresponding @MethodMetadata for a method
*
* @param name method name
* @return @MethodMetadata for this method
*/
public MethodMetadata getMethodMetadata(String name) {
if (!mMethodMetadata.containsKey(name)) {
throw new RuntimeException("method " + name + "does not exist for this module");
}
return mMethodMetadata.get(name);
}

/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
public String[] readLogBuffer() {
Expand Down
21 changes: 21 additions & 0 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,26 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
return false;
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getMethods() {
const auto& names_result = module_->method_names();
if (!names_result.ok()) {
facebook::jni::throwNewJavaException(
facebook::jni::gJavaLangIllegalArgumentException,
"Cannot get load module");
}
const auto& methods = names_result.get();
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
facebook::jni::JArrayClass<jstring>::newArray(methods.size());
int i = 0;
for (auto s : methods) {
facebook::jni::local_ref<facebook::jni::JString> method_name =
facebook::jni::make_jstring(s.c_str());
(*ret)[i] = method_name;
i++;
}
return ret;
}

facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
facebook::jni::alias_ref<jstring> methodName) {
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
Expand Down Expand Up @@ -458,6 +478,7 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
makeNativeMethod("loadMethodNative", ExecuTorchJni::load_method),
makeNativeMethod("readLogBufferNative", ExecuTorchJni::readLogBuffer),
makeNativeMethod("etdump", ExecuTorchJni::etdump),
makeNativeMethod("getMethods", ExecuTorchJni::getMethods),
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
});
}
Expand Down
Loading