Skip to content

Commit 6b2b72f

Browse files
authored
Merge branch 'main' into export-D80261684
2 parents be251b4 + d7fd78b commit 6b2b72f

File tree

65 files changed

+2151
-443
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

65 files changed

+2151
-443
lines changed

.ci/scripts/unittest-windows.ps1

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
param (
    [string]$editable
)

Set-PSDebug -Trace 1
$ErrorActionPreference = 'Stop'
$PSNativeCommandUseErrorActionPreference = $true

# Fail fast on native-command errors with a consistent message, instead of
# repeating the $LASTEXITCODE boilerplate after every step.
function Assert-LastExitCode {
    param([string]$stepName)
    if ($LASTEXITCODE -ne 0) {
        Write-Host "$stepName was unsuccessful. Exit code: $LASTEXITCODE."
        exit $LASTEXITCODE
    }
}

# Create and activate a fresh conda environment for the test run.
conda create --yes --quiet -n et python=3.12
conda activate et

# Activate the VS environment - this is required for Dynamo to work, as it uses MSVC.
# There are a bunch of environment variables that it requires.
# See https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line.
& "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Launch-VsDevShell.ps1" -Arch amd64

# Install test dependencies
pip install -r .ci/docker/requirements-ci.txt

# NOTE(review): bare .bat invocation assumes the script dir is on PATH or
# resolvable from the working directory - confirm in the CI image.
if ($editable -eq 'true') {
    install_executorch.bat --editable
} else {
    install_executorch.bat
}
Assert-LastExitCode "Installation"

# Run pytest with coverage
# pytest -n auto --cov=./ --cov-report=xml
pytest -v --full-trace -c pytest-windows.ini
Assert-LastExitCode "Pytest invocation"

.github/workflows/_unittest.yml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ on:
1919
required: false
2020
type: string
2121
description: Install ExecuTorch in editable mode or not.
22+
default: 'false'
2223
python-version:
2324
required: false
2425
type: string
@@ -52,3 +53,14 @@ jobs:
5253
# This is needed to get the prebuilt PyTorch wheel from S3
5354
${CONDA_RUN} --no-capture-output pip install awscli==1.37.21
5455
.ci/scripts/unittest-macos.sh --build-tool "${{ inputs.build-tool }}" --build-mode "${{ inputs.build-mode }}" --editable "${{ inputs.editable }}"
56+
57+
windows:
58+
if: ${{ inputs.build-tool == 'cmake' }}
59+
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
60+
with:
61+
submodules: 'recursive'
62+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
63+
timeout: 120
64+
script: |
65+
conda init powershell
66+
powershell .ci/scripts/unittest-windows.ps1 -editable "${{ inputs.editable }}"

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,7 @@ xcuserdata/
6565

6666
# Android
6767
*.aar
68+
69+
# Windows
70+
*.dll
71+
*.pyd

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
url = https://github.com/google/pthreadpool.git
2828
[submodule "extension/llm/tokenizers"]
2929
path = extension/llm/tokenizers
30-
url = https://github.com/pytorch-labs/tokenizers.git
30+
url = https://github.com/meta-pytorch/tokenizers.git
3131
[submodule "kernels/optimized/third-party/eigen"]
3232
path = kernels/optimized/third-party/eigen
3333
url = https://gitlab.com/libeigen/eigen.git

backends/apple/coreml/runtime/delegate/ETCoreMLModelManager.mm

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,12 +449,14 @@ - (nullable NSURL *)compiledModelURLWithIdentifier:(NSString *)identifier
449449
case ModelAssetType::CompiledModel: {
450450
// The model is already compiled; no further action needed.
451451
// Return the existing model URL.
452+
ETCoreMLLogInfo("The model in the pte file is pre-compiled. Skipping compilation.");
452453
return modelURL;
453454
}
454455

455456
case ModelAssetType::Model: {
456457
// The model is not compiled yet.
457458
// Compile the model at the specified URL with a maximum wait time of 5 minutes.
459+
ETCoreMLLogInfo("The model in the pte file is not pre-compiled. Compiling with a 5 min timeout.");
458460
NSURL *compiledModelURL = [ETCoreMLModelCompiler compileModelAtURL:modelURL
459461
maxWaitTimeInSeconds:(5 * 60)
460462
error:error];
@@ -490,6 +492,7 @@ - (nullable ETCoreMLAsset *)compiledModelAssetWithMetadata:(const ModelMetadata&
490492
error:error];
491493
if (compiledModelURL) {
492494
// Move the compiled model to the asset manager to transfer ownership.
495+
ETCoreMLLogInfo("Storing compiled asset with identifier=%@ in the asset manager.", identifier);
493496
compiledModelAsset = [self.assetManager storeAssetAtURL:compiledModelURL withIdentifier:identifier error:error];
494497
}
495498
}];

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,43 @@
88
import logging
99
from typing import cast
1010

11+
import torch
12+
1113
from executorch.exir.dialects._ops import ops as exir_ops
1214
from executorch.exir.pass_base import ExportPass
1315

1416
logger = logging.getLogger(__name__)
1517

1618

19+
def calculate_multiples(args):
    """Derive TOSA-style repeat multiples from ``expand_copy`` arguments.

    ``args[0]`` is either an FX node (shape taken from ``meta["val"]``) or a
    tensor-like object (shape taken from ``.data``); ``args[1]`` is the list
    of expand multiples. Returns a list of repeat counts, one per expanded
    dimension, where non-repeated dimensions are normalized to 1.
    """
    source = args[0]
    if isinstance(source, torch.fx.node.Node):
        tensor_data = source.meta["val"]
    else:
        tensor_data = source.data

    input_shape = tensor_data.shape

    expand_multiples = cast(list[int], args[1])
    expanded_rank = len(expand_multiples)

    # Expanded shape is 'input_shape' front-padded with ones.
    padding = expanded_rank - len(input_shape)
    extended_shape = [
        1 if axis < 0 else input_shape[axis]
        for axis in range(-padding, len(input_shape))
    ]

    # To convert expand arg to repeat arg, non-repeated dims should have
    # multiples[dim] = 1. Passing -1 to expand arg means
    # not changing the size of that dimension.
    return [
        m if m != -1 and extended_shape[i] == 1 else 1
        for i, m in enumerate(expand_multiples)
    ]
46+
47+
1748
class ConvertExpandCopyToRepeatPass(ExportPass):
1849
"""
1950
Replace expand copy with repeat since it is a repeat that can only repeat singleton dimensions.
@@ -26,23 +57,7 @@ def call_operator(self, op, args, kwargs, meta):
2657
if op != self.expand_copy:
2758
return super().call_operator(op, args, kwargs, meta)
2859

29-
input_shape = args[0].data.shape
30-
multiples = cast(list[int], args[1])
31-
expanded_rank = len(multiples)
32-
33-
# Expanded shape is 'input_shape' front-padded with ones.
34-
padding = expanded_rank - len(input_shape)
35-
extended_shape = [
36-
input_shape[i] if i >= 0 else 1 for i in range(-padding, len(input_shape))
37-
]
38-
39-
# To convert expand arg to repeat arg, non-repeated dims should have
40-
# multiples[dim] = 1. Passing -1 to expand arg means
41-
# not changing the size of that dimension.
42-
multiples = [
43-
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
44-
for i in range(expanded_rank)
45-
]
60+
multiples = calculate_multiples(args)
4661

4762
if all((x == 1 for x in multiples)):
4863
# All dimensions/repetitions occur only once. Remove node

backends/arm/_passes/remove_clone_pass.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,13 @@
66

77
# pyre-unsafe
88

9+
import logging
10+
911
from executorch.exir.dialects._ops import ops as exir_ops
1012
from executorch.exir.pass_base import ExportPass
1113

14+
logger = logging.getLogger(__name__)
15+
1216

1317
class RemoveClonePass(ExportPass):
1418
"""Remove all clones from graph_module"""
@@ -21,4 +25,10 @@ def call_operator(self, op, args, kwargs, meta):
2125
raise ValueError(
2226
f"clone operator expects exactly one argument, got {len(args)}"
2327
)
28+
29+
if "memory_format" in kwargs:
30+
logger.warning(
31+
f"Removing clone with memory_format '{kwargs['memory_format']}'."
32+
)
33+
2434
return args[0]

backends/arm/debug/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.

backends/arm/debug/schema.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import annotations
7+
8+
import json
9+
10+
from dataclasses import asdict, dataclass
11+
from typing import Any
12+
13+
import serializer.tosa_serializer as ts # type: ignore
14+
import torch
15+
16+
from torch.fx.traceback import NodeSource
17+
18+
19+
@dataclass
class TosaDebugSchema:
    """Identity of a single serialized TOSA operator in the debug log."""

    node_name: str      # string form of the serialized TOSA op
    operator_name: str  # TOSA operator name looked up from its id
    operator_id: int    # numeric TOSA operator id
24+
25+
26+
@dataclass
class ATenDebugSchema:
    """ATen-level identity of an FX node."""

    node_name: str      # FX node name
    operator_name: str  # callable __name__ or the raw string target

    @staticmethod
    def from_node(node: torch.fx.Node) -> "ATenDebugSchema":
        """Build an ATenDebugSchema from an FX node.

        node.target is Union[Callable[..., Any], str], so the operator name
        must be read differently depending on which one it actually is.
        """
        target = node.target
        name = target.__name__ if callable(target) else target
        return ATenDebugSchema(node_name=node.name, operator_name=name)
40+
41+
42+
@dataclass
class TorchDebugSchema:
    """Torch-level provenance for an FX node.

    Fields fall back to explanatory strings when the corresponding entry is
    missing from ``node.meta``.
    """

    stack_trace: list[str]                   # Python stack trace, split by line
    node_trace: list[dict[str, Any]] | str   # flattened from_node chain
    nn_module_stack: dict[str, Any] | str    # originating nn.Module stack
    torch_fn: tuple[str, str] | str          # torch function info, if recorded

    @staticmethod
    def serialize_node_trace(node_trace: list[NodeSource]) -> list[dict[str, Any]]:
        """Flatten the from_node dictionary to remove nesting."""
        flattened: list[dict[str, Any]] = []
        # Depth-first traversal via an explicit stack; top-level sources get a
        # sentinel parent id of -1.
        pending = [(source, -1) for source in node_trace]

        while pending:
            current, parent_graph_id = pending.pop()
            flattened.append(
                {
                    "name": current.name,
                    "target": current.target,
                    "graph_id": current.graph_id,
                    "pass_name": current.pass_name,
                    "action": current._get_action_string(),
                    "parent_graph_id": parent_graph_id,
                }
            )
            # Children inherit the current node's graph id as their parent.
            for child in current.from_node:
                pending.append((child, current.graph_id))

        return flattened

    @staticmethod
    def from_node(node: torch.fx.Node) -> "TorchDebugSchema":
        """Collect the torch-side debug metadata attached to an FX node."""
        node_trace: str | list[dict[str, Any]] = "No node trace available."
        if "from_node" in node.meta:
            # Flatten the node_trace dictionary, so there is no nesting
            node_trace = TorchDebugSchema.serialize_node_trace(node.meta["from_node"])

        raw_stack = node.meta.get("stack_trace", "No stack trace available")
        return TorchDebugSchema(
            stack_trace=raw_stack.split("\n"),
            node_trace=node_trace,
            nn_module_stack=node.meta.get(
                "nn_module_stack", "No module stack trace available"
            ),
            torch_fn=node.meta.get("torch_fn", "No torch_fn available"),
        )
94+
95+
96+
@dataclass
class DebugSchema:
    """One recorded debug event tying an ATen node to its TOSA counterpart."""

    event_id: int                # sequential index of this event
    aten_info: ATenDebugSchema   # ATen-level node identity
    tosa_info: TosaDebugSchema   # serialized TOSA operator identity
    torch_info: TorchDebugSchema # torch-side provenance metadata
102+
103+
104+
class DebugHook:
    """Accumulates per-operator debug events emitted during TOSA serialization
    and renders them as JSON."""

    def __init__(self) -> None:
        self._debug_events: list[DebugSchema] = []
        # Reverse mapping from TOSA 1.0 operator IDs to their names, built
        # from the attributes of the serializer's Op namespace.
        self.__op_id_to_name = {
            op_id: op_name for op_name, op_id in vars(ts.Op).items()
        }

    def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> None:
        """Record one event linking FX ``node`` to the serialized TOSA op."""
        event = DebugSchema(
            # event_id is simply the position of the event in the log.
            event_id=len(self._debug_events),
            aten_info=ATenDebugSchema.from_node(node),
            tosa_info=TosaDebugSchema(
                node_name=str(tosa_op),
                operator_name=self.__op_id_to_name[tosa_op_id],
                operator_id=tosa_op_id,
            ),
            torch_info=TorchDebugSchema.from_node(node),
        )
        self._debug_events.append(event)

    def serialize(self) -> str:
        """Return all recorded events as pretty-printed JSON."""
        return json.dumps([asdict(event) for event in self._debug_events], indent=4)

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# pyre-unsafe
77

88
from . import ( # noqa
9+
clone_support,
910
convolution_support,
1011
embedding_support,
1112
ethos_u55_support,

0 commit comments

Comments
 (0)