Commit 91800d9

Merge branch 'main' into dev_seqmse
2 parents: 9e15c41 + 364f493


46 files changed: +1612 -170 lines

.ci/scripts/test_model.sh

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ test_model() {
     bash examples/models/llava/install_requirements.sh
     STRICT="--no-strict"
   fi
-  if [[ "${MODEL_NAME}" == "qwen2_5" ]]; then
+  if [[ "${MODEL_NAME}" == "qwen2_5_1_5b" ]]; then
    # Install requirements for export_llama
    bash examples/models/llama/install_requirements.sh
    # Test export_llm script: python3 -m extension.llm.export.export_llm.

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -176,7 +176,7 @@ jobs:
           - model: phi_4_mini
             backend: portable
             runner: linux.arm64.m7g.4xlarge
-          - model: qwen2_5
+          - model: qwen2_5_1_5b
             backend: portable
             runner: linux.arm64.2xlarge
           - model: llama3_2_vision_encoder

README.md

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ To get started you can:

 - Visit the [Step by Step Tutorial](https://pytorch.org/executorch/stable/getting-started.html) to get things running locally and deploy a model to a device
 - Use this [Colab Notebook](https://colab.research.google.com/drive/1qpxrXC3YdJQzly3mRg-4ayYiOjC6rue3?usp=sharing) to start playing around right away
-- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), and [Llava](examples/models/llava/README.md)
+- Jump straight into LLM use cases by following specific instructions for popular open-source models such as [Llama](examples/models/llama/README.md), [Qwen 3](examples/models/qwen3/README.md), [Phi-4-mini](examples/models/phi_4_mini/README.md), [Llava](examples/models/llava/README.md), [Voxtral](examples/models/voxtral/README.md), and [LFM2](examples/models/lfm2/README.md).

 ## Feedback and Engagement

backends/arm/arm_backend.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 class ArmCompileSpecBuilder:
     class DebugMode(Enum):
         JSON = 1
+        TOSA = 2

     def __init__(self):
         self.compile_spec: List[CompileSpec] = []
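
The new TOSA member gives the Arm backend a second destination for debug data: JSON keeps the existing standalone JSON dump, while TOSA embeds each debug record into the compiled TOSA flatbuffer itself (see the schema.py and node_visitor.py changes below). A minimal sketch of branching on the mode; how the selected mode reaches the DebugHook is compile-spec plumbing not shown in this commit:

    from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

    # Illustrative only: DebugMode is a plain Enum, so code can branch on it
    # to decide where debug events should end up.
    mode = ArmCompileSpecBuilder.DebugMode.TOSA

    if mode == ArmCompileSpecBuilder.DebugMode.JSON:
        print("debug events go to a standalone JSON dump")
    else:
        print("debug events are embedded per-operator in the TOSA flatbuffer")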

backends/arm/debug/schema.py

Lines changed: 34 additions & 17 deletions
@@ -8,11 +8,13 @@
 import json

 from dataclasses import asdict, dataclass
-from typing import Any
+from typing import Any, Optional

 import serializer.tosa_serializer as ts  # type: ignore
 import torch

+from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
+
 from torch.fx.traceback import NodeSource

@@ -97,37 +99,52 @@ def from_node(node: torch.fx.Node) -> TorchDebugSchema:
 class DebugSchema:
     event_id: int
     aten_info: ATenDebugSchema
-    tosa_info: TosaDebugSchema
+    tosa_info: Optional[TosaDebugSchema]
     torch_info: TorchDebugSchema

+    def to_dict(self) -> dict[str, Any]:
+        output = asdict(self)
+
+        if self.tosa_info is None:
+            output.pop("tosa_info")
+
+        return output
+

 class DebugHook:
-    def __init__(self) -> None:
+    def __init__(self, debug_mode: ArmCompileSpecBuilder.DebugMode) -> None:
         self._debug_events: list[DebugSchema] = []
         self.__op_id_to_name = {}
+        self.mode = debug_mode

         # Build up a mapping from TOSA 1.0 operator IDs to their names
         for name, val in vars(ts.Op).items():
             self.__op_id_to_name[val] = name

-    def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
-        tosa_debug_info = TosaDebugSchema(
-            node_name=str(tosa_op),
-            operator_name=self.__op_id_to_name[tosa_op_id],
-            operator_id=tosa_op_id,
-        )
+    def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
+        tosa_debug_info = None
+
+        # If the debug data is being embedded into the TOSA flatbuffer,
+        # do not collect TosaDebugSchema data; it is redundant.
+        if self.mode != ArmCompileSpecBuilder.DebugMode.TOSA:
+            tosa_debug_info = TosaDebugSchema(
+                node_name=str(tosa_op),
+                operator_name=self.__op_id_to_name[tosa_op_id],
+                operator_id=tosa_op_id,
+            )

         aten_debug_info = ATenDebugSchema.from_node(node)
         torch_debug_info = TorchDebugSchema.from_node(node)

-        self._debug_events.append(
-            DebugSchema(
-                event_id=len(self._debug_events),
-                aten_info=aten_debug_info,
-                tosa_info=tosa_debug_info,
-                torch_info=torch_debug_info,
-            )
-        )
+        debug_info = DebugSchema(
+            event_id=len(self._debug_events),
+            aten_info=aten_debug_info,
+            tosa_info=tosa_debug_info,
+            torch_info=torch_debug_info,
+        )
+        self._debug_events.append(debug_info)
+
+        return debug_info

     def serialize(self) -> str:
-        return json.dumps([asdict(event) for event in self._debug_events], indent=4)
+        return json.dumps([event.to_dict() for event in self._debug_events], indent=4)
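
Net effect of the schema changes: in TOSA mode tosa_info is never collected, and to_dict() drops the key rather than serializing a null. A runnable sketch with simplified stand-ins for the real dataclasses (the actual DebugSchema also carries aten_info and torch_info):

    from dataclasses import asdict, dataclass
    from typing import Any, Optional

    # Simplified stand-ins for TosaDebugSchema / DebugSchema, showing only
    # how to_dict() omits tosa_info when it was never collected.
    @dataclass
    class TosaInfo:
        node_name: str
        operator_name: str

    @dataclass
    class Event:
        event_id: int
        tosa_info: Optional[TosaInfo]

        def to_dict(self) -> dict[str, Any]:
            output = asdict(self)
            if self.tosa_info is None:
                output.pop("tosa_info")  # drop the key instead of emitting null
            return output

    print(Event(0, TosaInfo("layer1/CONV2D", "CONV2D")).to_dict())
    # {'event_id': 0, 'tosa_info': {'node_name': 'layer1/CONV2D', 'operator_name': 'CONV2D'}}
    print(Event(1, None).to_dict())
    # {'event_id': 1}

add() now also returns the event it records, which is what lets _serialize_operator in node_visitor.py (below) embed the record into the operator's location field.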

backends/arm/operator_support/to_copy_support.py

Lines changed: 67 additions & 29 deletions
@@ -20,6 +20,8 @@

 logger = logging.getLogger(__name__)

+SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
+

 @register_tosa_support_check
 class ToCopySupported(SupportedTOSAOperatorCheck):
@@ -33,8 +35,6 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
         TosaSpecification.create_from_string("TOSA-1.0+FP"),
     ]

-    SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
-
     @staticmethod
     def _merge_supported_types(
         # pyre-ignore[11]
@@ -53,11 +53,22 @@ def _merge_supported_types(
         torch.int8: [torch.bool, torch.int16, torch.int32],
         torch.int16: [torch.bool, torch.int8, torch.int32],
         torch.int32: [torch.bool, torch.int8, torch.int16],
+        torch.int64: [torch.bool, torch.int8, torch.int16, torch.int32],
     }
     SUPPORTED_FLOAT_TYPES: SupportedTypeDict = {
         torch.int8: [torch.float16, torch.bfloat16, torch.float32],
         torch.int16: [torch.float16, torch.bfloat16, torch.float32],
         torch.int32: [torch.float16, torch.bfloat16, torch.float32],
+        # INT64 inputs to casts *should* be ok, since they should be rejected by
+        # CheckInt64InputsAndOutputs if the cast can't be done AOT.
+        torch.int64: [
+            torch.int8,
+            torch.int16,
+            torch.int32,
+            torch.float16,
+            torch.bfloat16,
+            torch.float32,
+        ],
         torch.bfloat16: [torch.int8, torch.int16, torch.int32, torch.float32],
         torch.float16: [torch.int8, torch.int16, torch.int32, torch.float32],
         torch.float32: [
@@ -71,29 +82,42 @@ def _merge_supported_types(
     ALL_SUPPORTED_TYPES = _merge_supported_types(
         SUPPORTED_INT_TYPES, SUPPORTED_FLOAT_TYPES
     )
-    POSSIBLE_TYPE_CONVERSIONS = {torch.int64: torch.int32}

     def is_node_tosa_supported(
         self, node: fx.Node, tosa_spec: TosaSpecification
     ) -> bool:
-        assert node.target in self.targets
-
-        supported_dtypes = (
-            self.ALL_SUPPORTED_TYPES
-            if tosa_spec.support_float()
-            else self.SUPPORTED_INT_TYPES
-        )
-        # Take into account possible type conversions
-        supported_dtypes.update(
-            (k, supported_dtypes[v])
-            for k, v in self.POSSIBLE_TYPE_CONVERSIONS.items()
-            if v in supported_dtypes
-        )

-        # Check input type
-        assert len(node.all_input_nodes) == 1
+        supported_dtypes: SupportedTypeDict = {}
+        if tosa_spec.support_integer():
+            supported_dtypes = self._merge_supported_types(
+                self.SUPPORTED_INT_TYPES, supported_dtypes
+            )
+        if tosa_spec.support_float():
+            supported_dtypes = self._merge_supported_types(
+                self.SUPPORTED_FLOAT_TYPES, supported_dtypes
+            )
+
+        if len(node.all_input_nodes) != 1:
+            self.reporter.report_reject(
+                node,
+                (
+                    "Expected exactly one input node, "
+                    f"got {len(node.all_input_nodes)} for {node.target}."
+                ),
+            )
+            return False
         input_val = node.all_input_nodes[0].meta["val"]
-        assert isinstance(input_val, torch._subclasses.FakeTensor)
+        if not isinstance(input_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor input, got "
+                    f"{type(input_val).__name__} for {node.target}."
+                ),
+            )
+            return False
+
+        # Check input type
         input_dtype = input_val.dtype
         if input_dtype not in supported_dtypes:
             self.reporter.report_reject(
@@ -104,14 +128,24 @@ def is_node_tosa_supported(

         # Check output type
         output_val = node.meta["val"]
-        assert isinstance(output_val, torch._subclasses.FakeTensor)
+        if not isinstance(output_val, torch._subclasses.FakeTensor):
+            self.reporter.report_reject(
+                node,
+                (
+                    "Invalid or missing meta: expected FakeTensor output, got "
+                    f"{type(output_val).__name__} for {node.target}."
+                ),
+            )
+            return False
         if output_val.dtype not in supported_dtypes[input_dtype]:
             self.reporter.report_reject(
                 node,
-                f"Output dtype {output_val.dtype} is not supported in "
-                f"{node.target} for input dtype {input_dtype}. "
-                f"Supported output types: "
-                f"{''.join(str(t) for t in supported_dtypes[input_dtype])}",
+                (
+                    f"Output dtype {output_val.dtype} is not supported in "
+                    f"{node.target} for input dtype {input_dtype}. "
+                    f"Supported output types: "
+                    f"{', '.join(str(t) for t in supported_dtypes[input_dtype])}"
+                ),
             )
             return False
@@ -120,20 +154,24 @@ def is_node_tosa_supported(
         if node.kwargs["memory_format"] in (torch.preserve_format,):
             self.reporter.report_reject(
                 node,
-                f"Argument 'memory_format' is not supported for "
-                f"{node.target} right now.",
+                (
+                    "Argument 'memory_format' is not supported for "
+                    f"{node.target} right now."
+                ),
             )
             return False

         # Check dim_order (to_dim_order_copy)
         if "dim_order" in node.kwargs:
             dim_order = node.kwargs["dim_order"]
             # pyre-ignore[6]
-            if dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
+            if dim_order is not None and dim_order != list(range(len(dim_order))):  # type: ignore[arg-type]
                 self.reporter.report_reject(
                     node,
-                    f"Argument {dim_order=} is not supported for "
-                    f"{node.target} right now.",
+                    (
+                        f"Argument {dim_order=} is not supported for "
+                        f"{node.target} right now."
+                    ),
                 )
                 return False
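
The rewrite drops the ALL_SUPPORTED_TYPES/POSSIBLE_TYPE_CONVERSIONS shortcut in favour of an explicit per-spec merge: an INT+FP target accepts the union of both tables, while an INT-only target never admits float casts. The body of _merge_supported_types is not shown in this commit, so the union-of-lists semantics in this sketch is an assumption:

    import torch

    SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

    # Assumed merge semantics: per input dtype, union the allowed output dtypes.
    def merge_supported_types(
        dtypes1: SupportedTypeDict, dtypes2: SupportedTypeDict
    ) -> SupportedTypeDict:
        merged = {k: list(v) for k, v in dtypes2.items()}
        for in_dtype, out_dtypes in dtypes1.items():
            existing = merged.setdefault(in_dtype, [])
            existing += [d for d in out_dtypes if d not in existing]
        return merged

    ints: SupportedTypeDict = {torch.int32: [torch.bool, torch.int8, torch.int16]}
    floats: SupportedTypeDict = {
        torch.int32: [torch.float16, torch.bfloat16, torch.float32]
    }

    # On an INT+FP spec both tables are merged; on an INT-only spec the float
    # table is skipped, so int32 -> float casts are rejected.
    print(merge_supported_types(ints, floats)[torch.int32])

This also explains the new int64 entries: an int64 cast is accepted here on the assumption that CheckInt64InputsAndOutputs rejects any cast that cannot be resolved ahead of time.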

backends/arm/operators/node_visitor.py

Lines changed: 14 additions & 7 deletions
@@ -5,10 +5,12 @@

 # pyre-unsafe

+import json
 from typing import Any, Dict, List, Optional

 import torch

+from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
 from executorch.backends.arm.debug.schema import DebugHook
 from executorch.backends.arm.tosa.mapping import TosaArg
 from executorch.backends.arm.tosa.specification import TosaSpecification
@@ -49,20 +51,25 @@ def _serialize_operator(
         outputs: List[str],
         attributes: Optional[Any] = None,
     ) -> None:
+        op_location = ""
+        if self.debug_hook:
+            debug_info = self.debug_hook.add(
+                node,
+                tosa_op=outputs[0],
+                tosa_op_id=tosa_op,
+            )
+
+            if self.debug_hook.mode == ArmCompileSpecBuilder.DebugMode.TOSA:
+                op_location = json.dumps(debug_info.to_dict())
+
         tosa_graph.addOperator(
             tosa_op,
             inputs=inputs,
             outputs=outputs,
             attributes=attributes,
+            location=op_location,
         )

-        if self.debug_hook:
-            self.debug_hook.add(
-                node,
-                tosa_op=outputs[0],
-                tosa_op_id=tosa_op,
-            )
-
     def define_node(
         self,
         node: torch.fx.Node,
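
Calling the hook before addOperator lets the returned event travel with the operator itself: in TOSA mode the serialized record becomes the operator's location string. A sketch of the payload shape; the field contents below are illustrative, not taken from a real export:

    import json

    # Illustrative DebugSchema.to_dict() output in TOSA mode: tosa_info has
    # been dropped (it would duplicate the flatbuffer the record is embedded
    # in), leaving the event id plus the ATen- and FX-level info.
    debug_info = {
        "event_id": 0,
        "aten_info": {"node_name": "aten_convolution_default"},
        "torch_info": {"stack_trace": "model.py:42 in forward"},
    }

    op_location = json.dumps(debug_info)
    # _serialize_operator then passes this string as
    # tosa_graph.addOperator(..., location=op_location), attaching the debug
    # record to the operator inside the TOSA flatbuffer.
    print(op_location)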

backends/arm/scripts/run_fvp.sh

Lines changed: 17 additions & 2 deletions
@@ -92,7 +92,7 @@ if [[ ${target} == *"ethos-u55"* ]]; then
         -C mps3_board.uart0.shutdown_on_eot=1 \
         -a "${elf_file}" \
         ${data_file} \
-        --timelimit ${timeout} 2>&1 | tee ${log_file} || true # seconds
+        --timelimit ${timeout} 2>&1 | sed 's/\r$//' | tee ${log_file} || true # seconds
     echo "[${BASH_SOURCE[0]}] Simulation complete, $?"
 elif [[ ${target} == *"ethos-u85"* ]]; then
     ${nobuf} ${fvp_model} \
@@ -104,13 +104,28 @@ elif [[ ${target} == *"ethos-u85"* ]]; then
         -C mps4_board.uart0.shutdown_on_eot=1 \
         -a "${elf_file}" \
         ${data_file} \
-        --timelimit ${timeout} 2>&1 | tee ${log_file} || true # seconds
+        --timelimit ${timeout} 2>&1 | sed 's/\r$//' | tee ${log_file} || true # seconds
     echo "[${BASH_SOURCE[0]}] Simulation complete, $?"
 else
     echo "Running ${elf_file} for ${target} is not supported"
     exit 1
 fi

+echo "Checking for an ETDump in the log"
+! grep "#\[RUN THIS\]" ${log_file} >/dev/null
+if [ $? != 0 ]; then
+    echo "Found ETDump in log!"
+    echo "#!/bin/sh" > etdump_script.sh
+    sed -n '/^#\[RUN THIS\]$/,/^#\[END\]$/p' ${log_file} >> etdump_script.sh
+    # You can run etdump_script.sh if you do
+    #   $ chmod a+x etdump_script.sh
+    #   $ ./etdump_script.sh
+    # But let's not trust the script, since a bad patch could run bad code on your machine.
+    grep ">etdump.bin" etdump_script.sh | cut -d\" -f2- | cut -d\" -f1 >etdump.base64
+    base64 -d etdump.base64 >etdump.bin
+    python3 -m devtools.inspector.inspector_cli --etdump_path etdump.bin --source_time_scale cycles --target_time_scale cycles
+fi
+
 echo "Checking for problems in log:"
 ! grep -E "^(F|E|\\[critical\\]|Hard fault.|Info: Simulation is stopping. Reason: CPU time has been exceeded.).*$" ${log_file}
 if [ $? != 0 ]; then
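
The new block scans the FVP log for an ETDump that the runner prints as a base64 payload between #[RUN THIS] and #[END] markers, decodes it into etdump.bin, and feeds it to the inspector CLI with cycle time scales on both axes. A rough Python equivalent of the extraction step, assuming the payload sits between the first pair of double quotes on the line that redirects into etdump.bin:

    import base64
    import re

    # Rough equivalent of the grep/cut/base64 pipeline above. Assumption: the
    # extracted script contains one line redirecting into etdump.bin whose
    # first quoted string is the base64 payload.
    def extract_etdump(script_path: str, out_path: str = "etdump.bin") -> bool:
        with open(script_path) as f:
            for line in f:
                if ">etdump.bin" not in line:
                    continue
                match = re.search(r'"([^"]*)"', line)  # text between first quotes
                if match:
                    with open(out_path, "wb") as out:
                        out.write(base64.b64decode(match.group(1)))
                    return True
        return False

    if extract_etdump("etdump_script.sh"):
        print("wrote etdump.bin; analyze with devtools.inspector.inspector_cli")

Decoding the payload directly, rather than executing the recovered script, mirrors the script's own caution: it avoids running untrusted commands from a CI log on the host machine.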
