Skip to content

Commit f6c803f

Browse files
authored
Merge branch 'main' into export-D85704977
2 parents 0ea115c + fd4eb9d commit f6c803f

28 files changed

+365
-268
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
e6f766c7d750d40603eee3f66c5915bac606b3ea
1+
556fc09a9f67f24ca5591ec049c5d0c347c5f62a

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ jobs:
347347
elif [[ ${{ matrix.os}} == "zephyr-preset" ]]; then
348348
setup_script_args="--target-toolchain zephyr"
349349
toolchain_prefix=arm-zephyr-eabi-
350-
threshold="135240" # 132 KiB
350+
threshold="135656" # 132 KiB
351351
toolchain_cmake=examples/zephyr/x86_64-linux-arm-zephyr-eabi-gcc.cmake
352352
else
353353
echo "Fail unsupport OS selection ${{ matrix.os }}"

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def calculate_multiples(args):
23+
"""Returns expand args converted to repeat args, and whether the expand changes the rank"""
2324
input_node_or_tensor = args[0]
2425

2526
if isinstance(input_node_or_tensor, torch.fx.node.Node):
@@ -45,7 +46,7 @@ def calculate_multiples(args):
4546
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
4647
for i in range(expanded_rank)
4748
]
48-
return multiples
49+
return multiples, expanded_rank != len(input_shape)
4950

5051

5152
class ConvertExpandCopyToRepeatPass(ArmPass):
@@ -62,9 +63,9 @@ def call_operator(self, op, args, kwargs, meta):
6263
if op != self.expand_copy:
6364
return super().call_operator(op, args, kwargs, meta)
6465

65-
multiples = calculate_multiples(args)
66+
multiples, changes_rank = calculate_multiples(args)
6667

67-
if all((x == 1 for x in multiples)):
68+
if all((x == 1 for x in multiples)) and not changes_rank:
6869
# All dimensions/repetitions occur only once. Remove node
6970
# altogether since it's in practice just a copy.
7071
logger.warning("Found redundant expand node (no-op). Removing it.")

backends/arm/_passes/decompose_embedding_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .arm_pass_utils import create_node, get_first_fake_tensor
1818

1919
logger = logging.getLogger(__name__)
20-
logger.setLevel(logging.WARNING)
2120

2221

2322
class DecomposeEmbeddingPass(ArmPass):

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class InsertInt32CastsAfterInt64PlaceholdersPass(ArmPass):
3636
# Key: op overload; Value: zero-based indices of positional args that must be i64.
3737
I64_INPUT_ARG_POSITIONS = {
3838
torch.ops.aten.one_hot.default: (0,),
39+
torch.ops.aten.index_copy_.default: (2,),
40+
torch.ops.aten.index_copy.default: (2,),
3941
}
4042

4143
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):

backends/arm/ethosu/backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _compile_tosa_flatbuffer(
6363
binary = vela_compile(
6464
tosa_flatbuffer,
6565
compile_flags,
66-
verbose=logger.getEffectiveLevel() == logging.INFO,
66+
verbose=logger.getEffectiveLevel() <= logging.INFO,
6767
intermediate_path=compile_spec.get_intermediate_path(),
6868
)
6969
return binary

backends/arm/operator_support/right_shift_support.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,5 @@ def is_node_tosa_supported(
4848
"""
4949
# TODO MLETORCH-525 Remove warning
5050
if tosa_spec.is_U55_subset:
51-
logging.warning(f"{node.target} may introduce one-off errors.")
51+
logger.warning(f"{node.target} may introduce one-off errors.")
5252
return True

backends/arm/operator_support/slice_copy_support.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.slice_copy`` in TOSA.
56
7+
Support slicing with unit step only; emit a warning and reject otherwise.
8+
9+
"""
610

711
import logging
812

@@ -19,19 +23,29 @@
1923

2024
@register_tosa_support_check
2125
class SliceCopySupported(SupportedTOSAOperatorCheck):
26+
"""Provide TOSA support check for ``aten.slice_copy``."""
27+
2228
targets = [exir_ops.edge.aten.slice_copy.Tensor]
2329

2430
tosa_specs = [
2531
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2632
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2733
]
2834

29-
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc]
35+
def is_node_tosa_supported(
36+
self, node: fx.Node, tosa_spec: TosaSpecification
37+
) -> bool: # type: ignore[override, misc]
38+
"""Return True if the node is supported by TOSA.
39+
40+
Accept slice_copy when the step is 1 (or unspecified). Warn and reject
41+
non-unit step sizes.
42+
43+
"""
3044
if tosa_spec not in self.tosa_specs:
3145
return False
3246

3347
args = node.args
3448
if len(args) == 5 and (step := args[4]) != 1:
35-
logging.warning(f"{node.target} with step size of {step} not supported.")
49+
logger.warning(f"{node.target} with step size of {step} not supported.")
3650
return False
3751
return True

backends/arm/test/conftest.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
import logging
76
import os
87
import random
9-
import sys
108
from typing import Any
119

1210
import pytest
@@ -29,8 +27,6 @@ def pytest_configure(config):
2927
if config.option.arm_run_tosa_version:
3028
pytest._test_options["tosa_version"] = config.option.arm_run_tosa_version
3129

32-
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
33-
3430

3531
def pytest_collection_modifyitems(config, items):
3632
pass

backends/arm/test/misc/test_debug_feats.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
)
2424
from executorch.backends.test.harness.stages import StageType
2525

26-
2726
input_t1 = Tuple[torch.Tensor] # Input x
2827

2928

@@ -261,14 +260,14 @@ def test_dump_tosa_debug_tosa(test_data: input_t1):
261260

262261

263262
@common.parametrize("test_data", Linear.inputs)
264-
def test_dump_tosa_ops(caplog, test_data: input_t1):
263+
def test_dump_tosa_ops(capsys, test_data: input_t1):
265264
aten_ops: list[str] = []
266265
exir_ops: list[str] = []
267266
pipeline = TosaPipelineINT[input_t1](Linear(), test_data, aten_ops, exir_ops)
268267
pipeline.pop_stage("run_method_and_compare_outputs")
269268
pipeline.dump_operator_distribution("to_edge_transform_and_lower")
270269
pipeline.run()
271-
assert "TOSA operators:" in caplog.text
270+
assert "TOSA operators:" in capsys.readouterr().out
272271

273272

274273
class Add(torch.nn.Module):
@@ -282,12 +281,15 @@ def forward(self, x):
282281

283282
@common.parametrize("test_data", Add.inputs)
284283
@common.XfailIfNoCorstone300
285-
def test_fail_dump_tosa_ops(caplog, test_data: input_t1):
284+
def test_fail_dump_tosa_ops(capsys, test_data: input_t1):
286285
aten_ops: list[str] = []
287286
exir_ops: list[str] = []
288287
pipeline = EthosU55PipelineINT[input_t1](
289288
Add(), test_data, aten_ops, exir_ops, use_to_edge_transform_and_lower=True
290289
)
291290
pipeline.dump_operator_distribution("to_edge_transform_and_lower")
292291
pipeline.run()
293-
assert "Can not get operator distribution for Vela command stream." in caplog.text
292+
assert (
293+
"Can not get operator distribution for Vela command stream."
294+
in capsys.readouterr().out
295+
)

0 commit comments

Comments
 (0)