
Commit 5826c87

Merge branch 'main' into op-stack

2 parents faba0c3 + 049c9fc


61 files changed: +2528 −296 lines

.ci/scripts/test_wheel_package_qnn.sh

Lines changed: 26 additions & 4 deletions
@@ -98,7 +98,7 @@ PYTHON_VERSION=$1
 # Check wheel does NOT contain qualcomm/sdk
 # ----------------------------
 echo "Checking wheel does not contain qualcomm/sdk..."
-SDK_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep "executorch/backends/qualcomm/sdk" || true)
+SDK_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep -E "executorch/backends/qualcomm/sdk" || true)
 if [ -n "$SDK_FILES" ]; then
   echo "ERROR: Wheel package contains unexpected qualcomm/sdk files:"
   echo "$SDK_FILES"
@@ -111,7 +111,7 @@ fi
 # Check .so files in the wheel
 # ----------------------------
 echo "Checking for .so files inside the wheel..."
-WHEEL_SO_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep "executorch/backends/qualcomm/python" || true)
+WHEEL_SO_FILES=$(unzip -l "$WHEEL_FILE" | awk '{print $4}' | grep -E "executorch/backends/qualcomm/python" || true)
 if [ -z "$WHEEL_SO_FILES" ]; then
   echo "ERROR: No .so files found in wheel under executorch/backends/qualcomm/python"
   exit 1
@@ -139,8 +139,30 @@ run_core_tests () {
   echo "=== [$LABEL] Installing wheel & deps ==="
   "$PIPBIN" install --upgrade pip
   "$PIPBIN" install "$WHEEL_FILE"
-  "$PIPBIN" install torch=="2.9.0.dev20250906" --index-url "https://download.pytorch.org/whl/nightly/cpu"
-  "$PIPBIN" install --pre torchao --index-url "https://download.pytorch.org/whl/nightly/cpu"
+  TORCH_VERSION=$(
+    "$PYBIN" - <<'PY'
+import runpy
+module_vars = runpy.run_path("torch_pin.py")
+print(module_vars["TORCH_VERSION"])
+PY
+  )
+
+  NIGHTLY_VERSION=$(
+    "$PYBIN" - <<'PY'
+import runpy
+module_vars = runpy.run_path("torch_pin.py")
+print(module_vars["NIGHTLY_VERSION"])
+PY
+  )
+  echo "=== [$LABEL] Install torch==${TORCH_VERSION}.${NIGHTLY_VERSION} ==="
+
+  # Install torch based on the pinned PyTorch version
+  "$PIPBIN" install torch=="${TORCH_VERSION}.${NIGHTLY_VERSION}" --index-url "https://download.pytorch.org/whl/nightly/cpu"
+
+  # Install torchao based on the pinned commit from third-party/ao submodule
+  pushd "$REPO_ROOT/third-party/ao" > /dev/null
+  USE_CPP=0 "$PYBIN" setup.py develop
+  popd > /dev/null
 
   echo "=== [$LABEL] Import smoke tests ==="
   "$PYBIN" -c "import executorch; print('executorch imported successfully')"

backends/aoti/utils.h

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
   switch (dtype) {
     case 6: // PyTorch's float32 dtype code
       return executorch::aten::ScalarType::Float;
+    case 15: // PyTorch's bfloat16 dtype code
+      return executorch::aten::ScalarType::BFloat16;
     // Future support for additional dtypes can be added here
     default:
       ET_LOG(Error, "Unsupported dtype: %d for ScalarType conversion", dtype);
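
The numeric cases mirror PyTorch's ScalarType enum (c10/core/ScalarType.h), where float32 is code 6 and bfloat16 is code 15. A small Python reference of that mapping, as a sketch of the upstream enum values rather than an ExecuTorch API:

# Subset of PyTorch ScalarType codes from c10/core/ScalarType.h.
# dtype_to_scalar_type() above currently handles only 6 and 15.
PYTORCH_DTYPE_CODES = {
    5: "Half",       # torch.float16
    6: "Float",      # torch.float32  -> ScalarType::Float
    7: "Double",     # torch.float64
    11: "Bool",
    15: "BFloat16",  # torch.bfloat16 -> ScalarType::BFloat16 (added here)
}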

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@
 import logging
 
 import shutil
+import tempfile
 import uuid
 from dataclasses import asdict, dataclass
 from enum import Enum
@@ -415,7 +416,7 @@ def preprocess_model(
     mlmodel: ct.models.MLModel, model_type: MODEL_TYPE
 ) -> PreprocessResult:
     identifier = "executorch_" + str(uuid.uuid4())
-    dir_path: Path = Path("tmp") / identifier
+    dir_path: Path = Path(tempfile.gettempdir()) / identifier
     model_dir_path: Path = dir_path / "lowered_module"
     model_spec: ct.proto.Model_pb2 = mlmodel.get_spec()
     logger.warning(
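
The change matters because Path("tmp") is relative to whatever directory the process happens to run in, while tempfile.gettempdir() resolves the system temp directory. A minimal illustration:

import tempfile
from pathlib import Path

# Relative: resolves under the current working directory, e.g. ./tmp/...
relative = Path("tmp") / "executorch_example"

# Absolute: resolves under the system temp dir, e.g. /tmp/executorch_example
# (or $TMPDIR on macOS), regardless of the working directory.
absolute = Path(tempfile.gettempdir()) / "executorch_example"

print(relative.is_absolute())  # False
print(absolute.is_absolute())  # True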

backends/arm/_passes/arm_pass_manager.py

Lines changed: 19 additions & 0 deletions
@@ -112,6 +112,8 @@
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
 from torch.fx import GraphModule
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.nn.modules import Module
 
 
 class ArmPassManager(PassManager):
@@ -355,3 +357,20 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeMaskedFill())
 
         return self._transform(graph_module)
+
+    def __call__(self, module: Module) -> PassResult:
+        try:
+            return super().__call__(module)
+        except Exception as e:
+            first_exception = e.__cause__ or e.__context__ or e
+            import re
+
+            message = e.args[0]
+            m = re.search(r"An error occurred when running the '([^']+)' pass", message)
+            if m:
+                pass_name = m.group(1)
+                first_exception.args = (
+                    f"{pass_name}: {first_exception.args[0]}",
+                    *first_exception.args[1:],
+                )
+            raise first_exception
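
The override leans on Python's exception chaining: __cause__ is set by "raise ... from ...", __context__ by an implicit re-raise, so e.__cause__ or e.__context__ or e recovers the innermost error. A standalone sketch of the pattern, with illustrative names rather than ExecuTorch code:

def run_pass():
    # The original failure we actually care about.
    raise ValueError("tensor rank mismatch")

try:
    try:
        run_pass()
    except Exception as inner:
        # A wrapper like PassManager re-raises with its own generic message.
        raise RuntimeError(
            "An error occurred when running the 'SomeArmPass' pass"
        ) from inner
except Exception as e:
    first = e.__cause__ or e.__context__ or e
    print(type(first).__name__, first)  # ValueError tensor rank mismatch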

backends/arm/_passes/fuse_batchnorm2d_pass.py

Lines changed: 26 additions & 8 deletions
@@ -12,6 +12,7 @@
     create_node,
     get_first_fake_tensor,
 )
+from executorch.backends.arm.common.debug import get_node_debug_info
 from executorch.backends.transforms.utils import (
     create_constant_placeholder,
     delete_constant_placeholder,
@@ -60,8 +61,16 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
             input_node = node.all_input_nodes[0]
             is_single_user = len(input_node.users) == 1
             bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = node.args[1:5]
-            assert bn_mean_node is not None, "Batchnorm mean node cannot be None."
-            assert bn_var_node is not None, "Batchnorm var node cannot be None."
+            if bn_mean_node is None:
+                raise RuntimeError(
+                    "BatchNorm mean buffer missing for node: "
+                    f"{get_node_debug_info(node, graph_module)}"
+                )
+            if bn_var_node is None:
+                raise RuntimeError(
+                    "BatchNorm variance buffer missing for node: "
+                    f"{get_node_debug_info(node, graph_module)}"
+                )
 
             epsilon = node.args[-1]
 
@@ -133,14 +142,23 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
                 input_node = new_input_node
             else:
                 input_weight_node, input_bias_node = input_node.args[1:3]
-                assert (
+                if not (
                     isinstance(input_weight_node, Node)
                     and input_weight_node.op == "placeholder"
-                ), "Parameter weight of convolution must be a placeholder"
-                assert (input_bias_node is None) or (
-                    isinstance(input_weight_node, Node)
-                    and input_weight_node.op == "placeholder"
-                ), "Parameter bias of convolution must be a placeholder or None"
+                ):
+                    raise RuntimeError(
+                        "Parameter weight of convolution must be a placeholder"
+                    )
+                if not (
+                    (input_bias_node is None)
+                    or (
+                        isinstance(input_weight_node, Node)
+                        and input_weight_node.op == "placeholder"
+                    )
+                ):
+                    raise RuntimeError(
+                        "Parameter bias of convolution must be a placeholder or None"
+                    )
 
                 input_weight_tensor = torch.Tensor(
                     get_param(self.exported_program, input_weight_node)
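
For context, conv+batchnorm fusion of the kind this pass performs follows the standard fold: with batch-norm statistics mu and var, affine parameters gamma and beta, and epsilon, the per-channel scale is gamma / sqrt(var + eps); the fused weight is W * scale (broadcast over output channels) and the fused bias is (b - mu) * scale + beta. A minimal numerical sketch of that identity, not the pass's actual FX-graph code:

import torch

def fuse_conv_bn(W, b, gamma, beta, mu, var, eps=1e-5):
    # One scale entry per output channel, broadcast over the weight tensor.
    scale = gamma / torch.sqrt(var + eps)
    W_fused = W * scale.reshape(-1, 1, 1, 1)
    b_fused = (b - mu) * scale + beta
    return W_fused, b_fused

# Check the fold against running conv + batchnorm separately.
conv = torch.nn.Conv2d(3, 8, 3)
bn = torch.nn.BatchNorm2d(8).eval()  # eval mode uses running stats
x = torch.randn(1, 3, 16, 16)

W_f, b_f = fuse_conv_bn(
    conv.weight, conv.bias, bn.weight, bn.bias,
    bn.running_mean, bn.running_var, bn.eps,
)
fused = torch.nn.functional.conv2d(x, W_f, b_f)
assert torch.allclose(fused, bn(conv(x)), atol=1e-5)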

backends/arm/arm_vela.py

Lines changed: 4 additions & 1 deletion
@@ -34,7 +34,10 @@ def vela_bin_pack_io(prefix, data):
         io_elem_size = data[prefix + "_elem_size"][i]
         io_offset = data[prefix + "_offset"][i]
         io_region = data[prefix + "_region"][i]
-        assert len(io_shape) == vela_io_shape_dims
+        if len(io_shape) != vela_io_shape_dims:
+            raise ValueError(
+                f"Expected {vela_io_shape_dims}D shape, got {len(io_shape)}D"
+            )
         inp_pad = io_shape.tolist()
         io_struct = struct.pack(
             "<iiiiiiiii", *inp_pad, io_elem_size, io_offset, io_region

backends/arm/common/arm_compile_spec.py

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,8 @@ def validate(self):
 
     def to_list(self):
         """Get the ArmCompileSpec in list form."""
-        assert self.tosa_spec
+        if not self.tosa_spec:
+            raise ValueError("tosa_spec must be set before calling to_list()")
 
         # Always supply a TOSA version
         compile_spec = [

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 60 additions & 0 deletions
@@ -384,3 +384,63 @@ def is_node_supported(
             return False
 
         return True
+
+
+class EthosU55CastCheck(OperatorSupportBase):
+    """Reject unsupported casts on U55.
+
+    U55 does not support casting from INT32 or any casts involving BOOL. Note that
+    casting from one dtype to the same dtype is a no-op and is supported.
+
+    Attributes:
+        reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
+
+    """
+
+    targets = [
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    ]
+
+    def __init__(self, reporter: WhyNoPartitionReporter):
+        """Initialize the check with a reporter.
+
+        Args:
+            reporter (WhyNoPartitionReporter): Reporter for rejection reasons.
+
+        """
+        super().__init__()
+        self.reporter = reporter
+
+    def is_node_supported(
+        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
+    ) -> bool:
+        """Return True if the node satisfies the cast constraints of U55.
+
+        Args:
+            submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
+            node (fx.Node): FX node to check.
+
+        Returns:
+            bool: True if supported; otherwise, False.
+
+        """
+        if node.target not in self.targets:
+            return True
+        input_dtype = get_first_fake_tensor(node.all_input_nodes[0]).dtype
+        output_dtype = get_first_fake_tensor(node).dtype
+        if input_dtype == output_dtype:
+            # This is ok as this will not result in a cast
+            return True
+        if input_dtype in (torch.bool, torch.int32):
+            self.reporter.report_reject(
+                node, f"Casting from {input_dtype} is not supported on U55."
+            )
+            return False
+        if output_dtype in (torch.bool,):
+            self.reporter.report_reject(
+                node, f"Casting to {output_dtype} is not supported on U55."
+            )
+            return False
+
+        return True
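
Distilled to a pure function, the rule above accepts any same-dtype no-op, rejects casts out of bool or int32, and rejects casts into bool. A quick sketch with illustrative cases, not the class itself:

import torch

def u55_cast_ok(input_dtype: torch.dtype, output_dtype: torch.dtype) -> bool:
    # Mirrors the EthosU55CastCheck decision rule.
    if input_dtype == output_dtype:
        return True  # same dtype: no cast is emitted
    if input_dtype in (torch.bool, torch.int32):
        return False  # U55 cannot cast from bool or int32
    if output_dtype == torch.bool:
        return False  # U55 cannot cast to bool
    return True

assert u55_cast_ok(torch.int32, torch.int32)      # no-op: accepted
assert u55_cast_ok(torch.int8, torch.int32)       # widening into int32: accepted
assert not u55_cast_ok(torch.int32, torch.int8)   # casting from int32: rejected
assert not u55_cast_ok(torch.int8, torch.bool)    # casting to bool: rejected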

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from executorch.backends.arm._passes.insert_table_ops import TableOps
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
+    EthosU55CastCheck,
     EthosU55DtypeSupport,
     EthosU55NotSupported,
     EthosU55TransposeCheck,
@@ -141,6 +142,7 @@ def tosa_support_factory(
         negative_checks.append(EthosU55DtypeSupport(reporter))
         negative_checks.append(EthosU55TransposeCheck(reporter))
         negative_checks.append(EthosU55ViewCheck(reporter))
+        negative_checks.append(EthosU55CastCheck(reporter))
 
     return chain(
         reporter.wrap_check(

backends/arm/process_node.py

Lines changed: 18 additions & 5 deletions
@@ -106,11 +106,16 @@ def process_inputs_to_parameters(
         ) from e
     parameter_data = get_param(edge_program, node)
 
-    assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
+    if not isinstance(parameter_data, torch.Tensor):
+        raise TypeError(
+            f"Expected parameter '{node.name}' to be a torch.Tensor, got "
+            f"{type(parameter_data).__name__}"
+        )
     parameter_values = parameter_data.detach().numpy()
 
     if tosa_arg.dtype == torch.float32:
-        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
+        if not tosa_spec.support_float():
+            raise ValueError(f"{tosa_spec} doesn't support float operations")
 
     # Handle special case for INT48 tensors
     special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
@@ -142,7 +147,11 @@ def process_inputs_to_buffers(
         ) from e
     buffer_data = get_buffer(edge_program, node)
 
-    assert isinstance(buffer_data, torch.Tensor), "Expect Attr to be tensor"
+    if not isinstance(buffer_data, torch.Tensor):
+        raise TypeError(
+            f"Expected buffer '{node.name}' to be a torch.Tensor, got "
+            f"{type(buffer_data).__name__}"
+        )
     buffer_values = buffer_data.detach().numpy()
 
     # TODO: fragile code for temporary fix
@@ -183,8 +192,12 @@ def process_placeholder(
     tosa_spec: TosaSpecification,
 ):
     """Wrapper for processing and serializing all types of placeholders"""
-    assert node.name == node.target, "Expect placeholder name and target to match"
-    assert 0 == len(node.args), "Can't handle default input values"
+    if node.name != node.target:
+        raise ValueError(
+            f"Placeholder name '{node.name}' does not match target '{node.target}'"
+        )
+    if len(node.args) != 0:
+        raise ValueError(f"Placeholder '{node.name}' must not have default values")
 
     if node.name in edge_program.graph_signature.user_inputs:
         process_inputs(node, tosa_graph, tosa_spec)
