Skip to content

Commit 50a6d2e

Browse files
committed
Update base for Update on "Export lora weights to sep file"
Differential Revision: [D83777195](https://our.internmc.facebook.com/intern/diff/D83777195/) [ghstack-poisoned]
2 parents b100c95 + f24351a commit 50a6d2e

File tree

188 files changed

+2868
-1948
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

188 files changed

+2868
-1948
lines changed

.ci/scripts/test_backend.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env bash
22
# Copyright (c) Meta Platforms, Inc. and affiliates.
33
# All rights reserved.
4+
# Copyright 2025 Arm Limited and/or its affiliates.
45
#
56
# This source code is licensed under the BSD-style license found in the
67
# LICENSE file in the root directory of this source tree.
@@ -58,6 +59,12 @@ fi
5859
if [[ "$FLOW" == *arm* ]]; then
5960
# Setup ARM deps.
6061
.ci/scripts/setup-arm-baremetal-tools.sh
62+
63+
if [[ "$FLOW" == *ethos_u* ]]; then
64+
# Prepare a test runner binary that can run on the Corstone-3x0 FVPs
65+
backends/arm/scripts/build_executorch.sh
66+
backends/arm/test/setup_testing.sh
67+
fi
6168
fi
6269

6370
if [[ $IS_MACOS -eq 1 ]]; then

.github/workflows/android-release-artifacts.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ jobs:
9090
fi
9191
9292
FLAVOR="${{ inputs.flavor }}"
93+
if [ ! -z "$FLAVOR" ]; then
94+
GRADLE_ARGS+=" -Dflavor=${FLAVOR}"
95+
fi
96+
9397
if [[ "$FLAVOR" == "vulkan" || -z "$FLAVOR" ]]; then
9498
curl -O https://sdk.lunarg.com/sdk/download/1.4.321.1/linux/vulkansdk-linux-x86_64-1.4.321.1.tar.xz
9599
tar xf vulkansdk-linux-x86_64-1.4.321.1.tar.xz -C /tmp

.github/workflows/test-backend-arm.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
uses: ./.github/workflows/_test_backend.yml
2424
with:
2525
backend: arm
26-
flows: '["arm_tosa"]'
26+
flows: '["arm_tosa_fp", "arm_tosa_int", "arm_ethos_u55", "arm_ethos_u85"]'
2727
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
2828
timeout: 120
2929
run-linux: true

CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
/backends/transforms @kimishpatel
1515
/backends/vulkan @SS-JIA
1616
/backends/xnnpack @digantdesai @mcr229
17+
/backends/nxp @robert-kalmar
1718

1819
/devtools @Gasoonjia
1920

@@ -33,6 +34,7 @@
3334
/examples/qualcomm @cccclai
3435
/examples/selective_build @lucylq @larryliu0820 @JacobSzwejbka
3536
/examples/xnnpack @digantdesai @mcr229
37+
/examples/nxp @robert-kalmar
3638

3739
/exir/backend @cccclai @kimishpatel @JacobSzwejbka
3840
/exir @JacobSzwejbka @larryliu0820

backends/arm/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,17 @@ runtime.python_library(
106106
"//caffe2:torch",
107107
]
108108
)
109+
runtime.python_library(
110+
name = "_factory",
111+
srcs = [
112+
"util/_factory.py"
113+
],
114+
deps = [
115+
":ethosu",
116+
":vgf",
117+
":arm_compile_spec",
118+
"//executorch/backends/arm/quantizer:lib",
119+
"//executorch/exir/backend:operator_support",
120+
"//executorch/exir/backend:compile_spec_schema",
121+
]
122+
)

backends/arm/operator_support/convolution_support.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.convolution`` in TOSA.
6+
7+
Provide general checks and hardware-specific constraints (e.g., U55 subset) for
8+
convolution nodes prior to delegation to the TOSA backend.
9+
10+
"""
511

612
from typing import cast
713

@@ -18,15 +24,24 @@
1824

1925
@register_tosa_support_check
2026
class ConvolutionSupported(SupportedTOSAOperatorCheck):
27+
"""Provide TOSA support check for convolutions."""
28+
2129
targets = [exir_ops.edge.aten.convolution.default]
2230

2331
tosa_specs = [
2432
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2533
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2634
]
2735

28-
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
36+
def is_node_tosa_supported(
37+
self, node: fx.Node, tosa_spec: TosaSpecification
38+
) -> bool:
39+
"""Return True if the node is supported by TOSA.
2940
41+
Reject transposed convolutions and convolutions with non-zero output
42+
padding. Apply additional hardware-specific constraints for U55.
43+
44+
"""
3045
# Not implemented
3146
transposed = cast(bool, node.args[6])
3247
output_padding = cast(list[int], node.args[7])
@@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4661
else:
4762
return True
4863

49-
def _is_node_supported_u55(self, node: fx.Node):
50-
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
64+
def _is_node_supported_u55(self, node: fx.Node) -> bool:
65+
"""Enforce Ethos-U55-specific constraints (Vela 4.2.0).
66+
67+
Check channel dimensions, kernel sizes, and stride/pad/dilation
68+
combinations permitted on U55.
5169
70+
Args:
71+
node (fx.Node): Convolution node to validate.
72+
73+
Returns:
74+
bool: True if supported; otherwise, False.
75+
76+
"""
5277
shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
5378
shape_out = node.meta["val"].shape
5479
kernel = cast(fx.Node, node.args[1]).meta["val"].shape
@@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node):
98123
return True
99124

100125
def _stride_condition(self, node: fx.Node) -> bool:
101-
"""This condition is somewhat complex but boils down
102-
to not supporting stride > 3, unless we have some special conditions.
103-
This condition is a simplified, relaxed version of the hardware constraint,
104-
since the actual constraint requires information not available
105-
here (without a lot of work).
126+
"""Check a simplified stride/padding/dilation constraint.
127+
128+
Disallow strides greater than 3 unless there is no padding and the
129+
dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``.
130+
131+
Args:
132+
node (fx.Node): Convolution node to evaluate.
133+
134+
Returns:
135+
bool: True if the condition is satisfied.
106136
107-
This means that we might accept ops that are not actually supported.
108137
"""
109138
strides = cast(list[int], node.args[3])
110139
has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))

backends/arm/operator_support/pool_2d_support.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide TOSA support checks for 2D pooling.
6+
7+
Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
8+
constraints including kernel size, stride, padding, and dimensionality.
9+
10+
"""
511

612
from typing import cast
713

@@ -20,16 +26,48 @@
2026

2127

2228
def kernel_check(kernel: tuple[int, int]) -> bool:
29+
"""Check if kernel size is within U55 constraints.
30+
31+
Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and
32+
the second element ``kernel[1]`` is in ``[1, 256]`` as required by the U55 profile.
33+
34+
Args:
35+
kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.
36+
37+
Returns:
38+
bool: True if the kernel passes validation.
39+
40+
"""
2341
if not (1 <= kernel[0] * kernel[1] <= 65536):
2442
return False
2543
return 1 <= kernel[1] <= 256
2644

2745

2846
def stride_check(strides: tuple[int, int]) -> bool:
47+
"""Check if strides are within U55 constraints.
48+
49+
Args:
50+
strides (tuple[int, int]): Vertical and horizontal strides.
51+
52+
Returns:
53+
bool: True if each stride is in ``[1, 3]``.
54+
55+
"""
2956
return all(1 <= stride <= 3 for stride in strides)
3057

3158

3259
def dim_check(shape=torch.Size) -> bool:
60+
"""Check if non-batch dims are within U55 constraints.
61+
62+
Verifies that all dimensions except batch are in ``[1, 65536]``.
63+
64+
Args:
65+
shape (torch.Size): Input tensor shape.
66+
67+
Returns:
68+
bool: True if all checked dimensions pass.
69+
70+
"""
3371
check = True
3472
for dim in shape[1:]:
3573
check &= 1 <= dim <= 65536
@@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:
3876

3977
@register_tosa_support_check
4078
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
79+
"""Provide TOSA support checks for ``aten.avg_pool2d``.
80+
81+
Applies additional constraints when targeting the U55 subset, including
82+
limits on kernel size, stride, padding behavior, and tensor ranks.
83+
84+
"""
85+
4186
targets = [
4287
exir_ops.edge.aten.avg_pool2d.default,
4388
]
@@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4893
]
4994

5095
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
96+
"""Return True if ``avg_pool2d`` satisfies U55 constraints.
97+
98+
Computes the effective TOSA padding (depending on ``count_include_pad``
99+
and ``divisor_override``) and validates kernel, stride, and shape limits.
100+
101+
"""
51102
if not tosa_spec.is_U55_subset:
52103
return True
53104

@@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
115166

116167
@register_tosa_support_check
117168
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
169+
"""Provide TOSA support checks for ``aten.max_pool2d_with_indices``.
170+
171+
Applies additional constraints when targeting the U55 subset, including
172+
limits on kernel size, stride, and tensor ranks.
173+
174+
"""
175+
118176
targets = [
119177
exir_ops.edge.aten.max_pool2d_with_indices.default,
120178
]
@@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
125183
]
126184

127185
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
186+
"""Return True if ``max_pool2d_with_indices`` satisfies U55
187+
constraints.
188+
"""
128189
if not tosa_spec.is_U55_subset:
129190
return True
130191

backends/arm/operators/op_repeat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def define_node(
4444
validate_valid_dtype(
4545
self.target,
4646
[inputs[0], output],
47-
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32],
47+
[ts.DType.INT8, ts.DType.INT32, ts.DType.INT16, ts.DType.FP32],
4848
output.tosa_spec,
4949
)
5050

backends/arm/quantizer/quantization_annotator.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
import operator
88
from dataclasses import dataclass
9-
from typing import Callable, List, Optional, Sequence
9+
from typing import Callable, cast, List, Optional, Sequence
1010

1111
import torch
1212
import torch.fx
@@ -137,11 +137,18 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
137137
node since histc op (in HistogramObserver) only works for values up to certain upper
138138
bound.
139139
"""
140+
HISTC_UPPER_BOUND = 3.4028235e15
140141
if node.op == "get_attr" and isinstance(node.target, str):
141142
tensor = _get_node_target(gm, node.target)
142143
# torch.histc works until this upper bound
143-
HISTC_UPPER_BOUND = 3.4028235e15
144144
return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
145+
if node.op == "call_function" and node.target in (
146+
torch.ops.aten.full.default,
147+
torch.ops.aten.full,
148+
torch.ops.aten.fill_.Scalar,
149+
):
150+
fill_value = cast(float, node.args[1])
151+
return abs(fill_value) > HISTC_UPPER_BOUND
145152
return False
146153

147154

@@ -358,9 +365,6 @@ def _match_pattern(
358365
torch.ops.aten.permute_copy.default,
359366
torch.ops.aten.avg_pool2d.default,
360367
torch.ops.aten.max_pool2d.default,
361-
torch.ops.aten.full.default,
362-
torch.ops.aten.full,
363-
torch.ops.aten.fill_.Scalar,
364368
torch.ops.aten.flatten.using_ints,
365369
torch.ops.aten.dropout.default,
366370
torch.ops.aten.dropout_.default,
@@ -518,9 +522,6 @@ def any_or_hardtanh_min_zero(n: Node):
518522
]
519523
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
520524
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
521-
if not isinstance(node.args[0], Node):
522-
return None
523-
524525
input_qspec = (
525526
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
526527
if is_output_annotated(node.args[0]) # type: ignore
@@ -578,7 +579,12 @@ def any_or_hardtanh_min_zero(n: Node):
578579
),
579580
]
580581
quant_properties.quant_output = None
581-
elif node.target in [torch.ops.aten.scalar_tensor.default]:
582+
elif node.target in [
583+
torch.ops.aten.scalar_tensor.default,
584+
torch.ops.aten.full.default,
585+
torch.ops.aten.full,
586+
torch.ops.aten.fill_.Scalar,
587+
]:
582588
quant_properties.quant_inputs = []
583589
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
584590
elif node.target in [operator.getitem]:

backends/arm/test/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
27
load(":targets.bzl", "define_arm_tests")
38

@@ -58,6 +63,7 @@ runtime.python_library(
5863
"//executorch/backends/arm/quantizer:lib",
5964
"//executorch/backends/arm/tosa:mapping",
6065
"//executorch/backends/arm:vgf",
66+
"//executorch/backends/arm:_factory",
6167
"//executorch/devtools/backend_debug:delegation_info",
6268
"//executorch/exir/backend:operator_support",
6369
"fbsource//third-party/pypi/tabulate:tabulate",

0 commit comments

Comments
 (0)