Skip to content

Commit bc2c54d

Browse files
author
ssjia
committed
Update on "[ET-VK][ez] Fix partitioner logic of finding keepdim arg of reduce ops"
Title says it all. For reduce ops, their signature are not all alike so some extra legwork needs to be done to identify specific arguments that need to be checked. Also included a small update to partitioner logging to improve debuggability. Differential Revision: [D80741737](https://our.internmc.facebook.com/intern/diff/D80741737/) [ghstack-poisoned]
2 parents 05a62d0 + 9e32293 commit bc2c54d

File tree

28 files changed

+563
-244
lines changed

28 files changed

+563
-244
lines changed

examples/models/llava/runner/llava_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ET_EXPERIMENTAL LlavaRunner {
4242
const float temperature = 0.8f)
4343
: temperature_(temperature),
4444
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
45-
io_manager_(std::make_unique<IOManager>()),
45+
io_manager_(std::make_unique<IOManager>(*module_)),
4646
tokenizer_path_(tokenizer_path) {
4747
ET_LOG(
4848
Info,

examples/models/qwen3/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ python -m extension.llm.export.export_llm \
4545
### Example run
4646
With ExecuTorch pybindings:
4747
```
48-
python -m examples.models.llama.runner.native
48+
python -m examples.models.llama.runner.native \
4949
--model qwen3_0_6b \
5050
--pte qwen3_0_6b.pte \
5151
--tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
@@ -59,9 +59,9 @@ python -m examples.models.llama.runner.native
5959

6060
With ExecuTorch's sample c++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
6161
```
62-
cmake-out/examples/models/llama/llama_main
63-
--model_path qwen3_0_6b.pte
64-
--tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json
62+
cmake-out/examples/models/llama/llama_main \
63+
--model_path qwen3_0_6b.pte \
64+
--tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
6565
--prompt="Who is the president of the US?"
6666
```
6767

extension/android/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ non_fbcode_target(_kind = fb_android_library,
1010
"executorch_android/src/main/java/org/pytorch/executorch/DType.java",
1111
"executorch_android/src/main/java/org/pytorch/executorch/EValue.java",
1212
"executorch_android/src/main/java/org/pytorch/executorch/ExecuTorchRuntime.java",
13+
"executorch_android/src/main/java/org/pytorch/executorch/ExecutorchRuntimeException.java",
1314
"executorch_android/src/main/java/org/pytorch/executorch/MethodMetadata.java",
1415
"executorch_android/src/main/java/org/pytorch/executorch/Module.java",
1516
"executorch_android/src/main/java/org/pytorch/executorch/Tensor.java",

extension/android/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ executorch_target_link_options_shared_lib(executorch)
7171

7272
add_library(
7373
executorch_jni SHARED jni/jni_layer.cpp jni/log.cpp jni/jni_layer_runtime.cpp
74+
jni/jni_helper.cpp
7475
)
7576

7677
set(link_libraries)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch;
10+
11+
import java.util.Collections;
12+
import java.util.HashMap;
13+
import java.util.Map;
14+
15+
public class ExecutorchRuntimeException extends RuntimeException {
16+
// Error code constants - keep in sync with runtime/core/error.h
17+
// System errors
18+
public static final int OK = 0x00;
19+
public static final int INTERNAL = 0x01;
20+
public static final int INVALID_STATE = 0x02;
21+
public static final int END_OF_METHOD = 0x03;
22+
23+
// Logical errors
24+
public static final int NOT_SUPPORTED = 0x10;
25+
public static final int NOT_IMPLEMENTED = 0x11;
26+
public static final int INVALID_ARGUMENT = 0x12;
27+
public static final int INVALID_TYPE = 0x13;
28+
public static final int OPERATOR_MISSING = 0x14;
29+
public static final int REGISTRATION_EXCEEDING_MAX_KERNELS = 0x15;
30+
public static final int REGISTRATION_ALREADY_REGISTERED = 0x16;
31+
32+
// Resource errors
33+
public static final int NOT_FOUND = 0x20;
34+
public static final int MEMORY_ALLOCATION_FAILED = 0x21;
35+
public static final int ACCESS_FAILED = 0x22;
36+
public static final int INVALID_PROGRAM = 0x23;
37+
public static final int INVALID_EXTERNAL_DATA = 0x24;
38+
public static final int OUT_OF_RESOURCES = 0x25;
39+
40+
// Delegate errors
41+
public static final int DELEGATE_INVALID_COMPATIBILITY = 0x30;
42+
public static final int DELEGATE_MEMORY_ALLOCATION_FAILED = 0x31;
43+
public static final int DELEGATE_INVALID_HANDLE = 0x32;
44+
45+
private static final Map<Integer, String> ERROR_CODE_MESSAGES;
46+
47+
static {
48+
Map<Integer, String> map = new HashMap<>();
49+
50+
// System errors
51+
map.put(OK, "Operation successful");
52+
map.put(INTERNAL, "Internal error");
53+
map.put(INVALID_STATE, "Invalid state");
54+
map.put(END_OF_METHOD, "End of method reached");
55+
// Logical errors
56+
map.put(NOT_SUPPORTED, "Operation not supported");
57+
map.put(NOT_IMPLEMENTED, "Operation not implemented");
58+
map.put(INVALID_ARGUMENT, "Invalid argument");
59+
map.put(INVALID_TYPE, "Invalid type");
60+
map.put(OPERATOR_MISSING, "Operator missing");
61+
map.put(REGISTRATION_EXCEEDING_MAX_KERNELS, "Exceeded max kernels");
62+
map.put(REGISTRATION_ALREADY_REGISTERED, "Kernel already registered");
63+
// Resource errors
64+
map.put(NOT_FOUND, "Resource not found");
65+
map.put(MEMORY_ALLOCATION_FAILED, "Memory allocation failed");
66+
map.put(ACCESS_FAILED, "Access failed");
67+
map.put(INVALID_PROGRAM, "Invalid program");
68+
map.put(INVALID_EXTERNAL_DATA, "Invalid external data");
69+
map.put(OUT_OF_RESOURCES, "Out of resources");
70+
// Delegate errors
71+
map.put(DELEGATE_INVALID_COMPATIBILITY, "Delegate invalid compatibility");
72+
map.put(DELEGATE_MEMORY_ALLOCATION_FAILED, "Delegate memory allocation failed");
73+
map.put(DELEGATE_INVALID_HANDLE, "Delegate invalid handle");
74+
ERROR_CODE_MESSAGES = Collections.unmodifiableMap(map);
75+
}
76+
77+
static class ErrorHelper {
78+
static String formatMessage(int errorCode, String details) {
79+
String baseMessage = ERROR_CODE_MESSAGES.get(errorCode);
80+
if (baseMessage == null) {
81+
baseMessage = "Unknown error code 0x" + Integer.toHexString(errorCode);
82+
}
83+
return "[Executorch Error 0x"
84+
+ Integer.toHexString(errorCode)
85+
+ "] "
86+
+ baseMessage
87+
+ ": "
88+
+ details;
89+
}
90+
}
91+
92+
private final int errorCode;
93+
94+
public ExecutorchRuntimeException(int errorCode, String details) {
95+
super(ErrorHelper.formatMessage(errorCode, details));
96+
this.errorCode = errorCode;
97+
}
98+
99+
public int getErrorCode() {
100+
return errorCode;
101+
}
102+
103+
// Idiomatic Java exception for invalid arguments.
104+
public static class ExecutorchInvalidArgumentException extends IllegalArgumentException {
105+
private final int errorCode = INVALID_ARGUMENT;
106+
107+
public ExecutorchInvalidArgumentException(String details) {
108+
super(ErrorHelper.formatMessage(INVALID_ARGUMENT, details));
109+
}
110+
111+
public int getErrorCode() {
112+
return errorCode;
113+
}
114+
}
115+
116+
// Factory method to create an exception of the appropriate subclass.
117+
public static RuntimeException makeExecutorchException(int errorCode, String details) {
118+
switch (errorCode) {
119+
case INVALID_ARGUMENT:
120+
return new ExecutorchInvalidArgumentException(details);
121+
default:
122+
return new ExecutorchRuntimeException(errorCode, details);
123+
}
124+
}
125+
}

extension/android/jni/BUCK

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ load(":build_defs.bzl", "ET_JNI_COMPILER_FLAGS")
77

88
oncall("executorch")
99

10+
# Define the common JNI source files
11+
shared_srcs = [
12+
"jni_layer.cpp",
13+
"jni_layer_runtime.cpp",
14+
"jni_helper.cpp",
15+
"log.cpp",
16+
]
17+
1018
non_fbcode_target(_kind = executorch_generated_lib,
1119
name = "generated_op_lib_optimized",
1220
custom_ops_aten_kernel_deps = [
@@ -28,7 +36,7 @@ non_fbcode_target(_kind = executorch_generated_lib,
2836

2937
non_fbcode_target(_kind = fb_android_cxx_library,
3038
name = "executorch_jni",
31-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
39+
srcs = shared_srcs,
3240
allow_jni_merging = False,
3341
compiler_flags = ET_JNI_COMPILER_FLAGS,
3442
soname = "libexecutorch.$(ext)",
@@ -49,7 +57,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
4957

5058
non_fbcode_target(_kind = fb_android_cxx_library,
5159
name = "executorch_jni_full",
52-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp"],
60+
srcs = shared_srcs,
5361
allow_jni_merging = False,
5462
compiler_flags = ET_JNI_COMPILER_FLAGS,
5563
soname = "libexecutorch.$(ext)",
@@ -71,7 +79,7 @@ non_fbcode_target(_kind = fb_android_cxx_library,
7179

7280
non_fbcode_target(_kind = fb_android_cxx_library,
7381
name = "executorch_training_jni",
74-
srcs = ["jni_layer.cpp", "log.cpp", "jni_layer_runtime.cpp", "jni_layer_training.cpp"],
82+
srcs = shared_srcs + ["jni_layer_training.cpp"],
7583
allow_jni_merging = False,
7684
compiler_flags = ET_JNI_COMPILER_FLAGS + [
7785
"-DEXECUTORCH_BUILD_EXTENSION_TRAINING",
@@ -98,11 +106,9 @@ non_fbcode_target(_kind = fb_android_cxx_library,
98106

99107
non_fbcode_target(_kind = fb_android_cxx_library,
100108
name = "executorch_llama_jni",
101-
srcs = [
102-
"jni_layer.cpp",
103-
"jni_layer_llama.cpp",
104-
"jni_layer_runtime.cpp",
105-
],
109+
exclude_files = ["log.cpp"]
110+
shared_srcs_filtered = [f for f in shared_srcs if f not in exclude_files]
111+
srcs = shared_srcs_filtered + ["jni_layer_llama.cpp"]
106112
allow_jni_merging = False,
107113
compiler_flags = ET_JNI_COMPILER_FLAGS + [
108114
"-DEXECUTORCH_BUILD_LLAMA_JNI",
@@ -145,6 +151,10 @@ runtime.export_file(
145151
name = "jni_layer_runtime.cpp",
146152
)
147153

154+
runtime.export_file(
155+
name = "jni_helper.cpp",
156+
)
157+
148158
runtime.cxx_library(
149159
name = "jni_headers",
150160
exported_headers = [
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "jni_helper.h"
10+
11+
namespace executorch::jni_helper {
12+
13+
void throwExecutorchException(uint32_t errorCode, const std::string& details) {
14+
// Get the current JNI environment
15+
auto env = facebook::jni::Environment::current();
16+
17+
// Find the Java ExecutorchRuntimeException class
18+
static auto exceptionClass = facebook::jni::findClassLocal(
19+
"org/pytorch/executorch/ExecutorchRuntimeException");
20+
21+
// Find the static factory method: makeExecutorchException(int, String)
22+
static auto makeExceptionMethod = exceptionClass->getStaticMethod<
23+
facebook::jni::local_ref<facebook::jni::JThrowable>(
24+
int, facebook::jni::alias_ref<facebook::jni::JString>)>(
25+
"makeExecutorchException",
26+
"(ILjava/lang/String;)Lorg/pytorch/executorch/ExecutorchRuntimeException;");
27+
28+
auto jDetails = facebook::jni::make_jstring(details);
29+
// Call the factory method to create the exception object
30+
auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
31+
facebook::jni::throwNewJavaException(exception.get());
32+
}
33+
34+
} // namespace executorch::jni_helper

extension/android/jni/jni_helper.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <fbjni/fbjni.h>
12+
#include <string>
13+
14+
namespace executorch::jni_helper {
15+
16+
/**
17+
* Throws a Java ExecutorchRuntimeException corresponding to the given error
18+
* code and details. Uses the Java factory method
19+
* ExecutorchRuntimeException.makeExecutorchException(int, String).
20+
*
21+
* @param errorCode The error code from the C++ Executorch runtime.
22+
* @param details Additional details to include in the exception message.
23+
*/
24+
void throwExecutorchException(uint32_t errorCode, const std::string& details);
25+
26+
} // namespace executorch::jni_helper

0 commit comments

Comments
 (0)