Skip to content

Commit ce3dd25

Browse files
Merge branch 'main' into node_quant_metadata
2 parents 43382ab + de56c81 commit ce3dd25

File tree

130 files changed

+935
-506
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

130 files changed

+935
-506
lines changed

.ci/scripts/setup-windows-msvc.ps1

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
conda create --yes --quiet -n et python=3.12
2+
conda activate et
3+
4+
# Install cmake
5+
conda install -y cmake
6+
7+
# Activate the VS environment - this is required for MSVC to work
8+
# There are a bunch of environment variables that it requires.
9+
# See https://learn.microsoft.com/en-us/cpp/build/building-on-the-command-line.
10+
& "C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Launch-VsDevShell.ps1" -Arch amd64
11+
12+
# Install CI requirements
13+
pip install -r .ci/docker/requirements-ci.txt
14+
15+
# Create build directory
16+
$buildDir = "cmake-out-msvc"
17+
if (Test-Path -Path $buildDir) {
18+
Remove-Item -Path $buildDir -Recurse -Force
19+
}
20+
New-Item -Path $buildDir -ItemType Directory
21+
22+
# Configure CMake with MSVC (not ClangCL) and disable custom/quantized ops
23+
cmake -S . -B $buildDir `
24+
-DCMAKE_BUILD_TYPE=Release `
25+
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=ON `
26+
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON `
27+
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON `
28+
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON `
29+
-DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON `
30+
-DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON `
31+
-DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON `
32+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM=OFF `
33+
-DEXECUTORCH_BUILD_KERNELS_CUSTOM_AOT=OFF `
34+
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=OFF `
35+
-DEXECUTORCH_BUILD_XNNPACK=ON `
36+
-DEXECUTORCH_BUILD_EXTENSION_LLM=ON `
37+
-DEXECUTORCH_BUILD_EXTENSION_LLM_RUNNER=ON
38+
39+
if ($LASTEXITCODE -ne 0) {
40+
Write-Host "CMake configuration failed. Exit code: $LASTEXITCODE."
41+
exit $LASTEXITCODE
42+
}
43+
44+
# Build with MSVC
45+
cmake --build $buildDir --config Release -j16
46+
47+
if ($LASTEXITCODE -ne 0) {
48+
Write-Host "Build failed. Exit code: $LASTEXITCODE."
49+
exit $LASTEXITCODE
50+
}
51+
52+
Write-Host "MSVC build completed successfully!"

.github/workflows/windows-msvc.yml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Windows MSVC Build
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
- release/*
8+
tags:
9+
- ciflow/trunk/*
10+
pull_request:
11+
paths:
12+
- .ci/docker/ci_commit_pins/pytorch.txt
13+
- .ci/scripts/**
14+
workflow_dispatch:
15+
16+
concurrency:
17+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
18+
cancel-in-progress: true
19+
20+
jobs:
21+
build-windows-msvc:
22+
name: build-windows-msvc
23+
uses: pytorch/test-infra/.github/workflows/windows_job.yml@main
24+
with:
25+
submodules: 'recursive'
26+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
27+
timeout: 60
28+
script: |
29+
conda init powershell
30+
powershell -Command "& {
31+
Set-PSDebug -Trace 1
32+
\$ErrorActionPreference = 'Stop'
33+
\$PSNativeCommandUseErrorActionPreference = \$true
34+
.ci/scripts/setup-windows-msvc.ps1
35+
}"

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@
8888
from .mm_to_bmm_pass import ConvertMmToBmmPass # noqa
8989
from .remove_noop_pass import RemoveNoopPass # noqa
9090
from .replace_scalar_with_tensor_pass import ( # noqa
91-
ReplaceScalarWithTensorArgPassTOSABI,
92-
ReplaceScalarWithTensorArgPassTOSAMI,
91+
ReplaceScalarWithTensorByProfilePass,
9392
)
9493
from .rewrite_conv2d_pass import RewriteConv2dPass # noqa
9594
from .rewrite_matmul import RewriteMatmulPass # noqa

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import itertools
98
import operator
@@ -52,7 +51,7 @@ def _match_partition_to_node(
5251
raise RuntimeError(f"Cannot find an input node which matches, {node}.")
5352

5453
def call(self, graph_module: GraphModule) -> PassResult:
55-
matmul_partitions = get_source_partitions(
54+
matmul_partitions_map = get_source_partitions(
5655
graph_module.graph,
5756
[
5857
torch.matmul,
@@ -61,7 +60,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6160
None,
6261
)
6362
matmul_partitions = list(
64-
itertools.chain.from_iterable(matmul_partitions.values())
63+
itertools.chain.from_iterable(matmul_partitions_map.values())
6564
)
6665
matmul_targets = {
6766
exir_ops.edge.aten.bmm.default,
@@ -89,7 +88,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
8988
# Create new dq-node before matmul
9089
dq_node = create_node(
9190
graph=graph_module.graph,
92-
op_target=cast(EdgeOpOverload, input_node.target), # type: ignore[arg-type]
91+
op_target=cast(EdgeOpOverload, input_node.target),
9392
)
9493
dq_node.args = (node, *input_node.args[1:])
9594
matmul_node.replace_input_with(node, dq_node)
@@ -110,7 +109,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
110109
# Create q-node after matmul
111110
q_node = create_node(
112111
graph=graph_module.graph,
113-
op_target=cast(EdgeOpOverload, partition_output.target), # type: ignore[arg-type]
112+
op_target=cast(EdgeOpOverload, partition_output.target),
114113
)
115114
matmul_node.replace_all_uses_with(q_node)
116115
q_node.args = (matmul_node, *partition_output.args[1:])

backends/arm/_passes/arm_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import traceback
98
from abc import abstractmethod

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
9-
108

119
from collections import defaultdict
1210

@@ -89,8 +87,7 @@
8987
QuantizeOperatorArguments,
9088
RemoveNoopPass,
9189
ReplaceInfValues,
92-
ReplaceScalarWithTensorArgPassTOSABI,
93-
ReplaceScalarWithTensorArgPassTOSAMI,
90+
ReplaceScalarWithTensorByProfilePass,
9491
RetraceFoldedDtypesPass,
9592
RewriteConv2dPass,
9693
RewriteMatmulPass,
@@ -174,7 +171,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
174171
self.add_pass(CastToInt32Pass())
175172

176173
self.add_pass(CastBoolToInt8Pass())
177-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
174+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
178175
self.add_pass(AnnotateDecomposedMatmulPass())
179176
self.add_pass(QuantizeOperatorArguments())
180177
self.add_pass(ConvertELUParamsPass())
@@ -194,7 +191,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
194191
self.add_pass(ConvertExpandCopyToRepeatPass())
195192
self.add_pass(UnsqueezeBeforeRepeatPass())
196193
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
197-
self.add_pass(DecomposeSumPass())
198194
self.add_pass(DecomposeCumsumPass(exported_program))
199195
self.add_pass(Conv1dUnsqueezePass())
200196
self.add_pass(DecomposeMaxPool2DPass())
@@ -215,10 +211,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
215211
self.add_pass(RewriteMatmulPass())
216212
self.add_pass(RewriteUpsamplePass())
217213
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
214+
self.add_pass(InsertRescaleInt32Pass())
215+
self.add_pass(DecomposeSumPass())
218216
self.add_pass(ToTosaMemoryFormatPass(exported_program))
219217
self.add_pass(RemoveNoopPass())
220218
self.add_pass(InsertRescalePass())
221-
self.add_pass(InsertRescaleInt32Pass())
222219

223220
self.validate_constraints_mandatory()
224221
return self._transform(exported_program.graph_module)
@@ -244,7 +241,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
244241
self.add_pass(DecomposeSinhPass())
245242
self.add_pass(DecomposeSignPass())
246243
self.add_pass(DecomposeDivTensorModePass())
247-
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
244+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
248245
self.add_pass(DecomposeEmbeddingPass())
249246
self.add_pass(FuseQuantizedActivationPass())
250247
self.add_pass(RemoveGetItemPass())
@@ -337,7 +334,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
337334
self.add_pass(DecomposeAddmmPass())
338335
self.add_pass(DecomposeDivTensorModePass())
339336
self.add_pass(DecomposeAddSubAlphaPass())
340-
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
337+
self.add_pass(ReplaceScalarWithTensorByProfilePass())
341338
self.add_pass(ScalarsToAttributePass())
342339
self.add_pass(DecomposeGroupNormPass())
343340
self.add_pass(DecomposeLayerNormPass())
@@ -361,7 +358,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
361358

362359
self.add_pass(ConvertMinMaxPass())
363360
self.add_pass(ReplaceInfValues())
364-
self.add_pass(DecomposeSumPass())
365361

366362
if not self.tosa_spec.is_U55_subset:
367363
# Uses where which is not supported on Ethos-U55

backends/arm/_passes/arm_pass_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
# pyre-unsafe
98

109
import traceback
1110
from inspect import isclass
@@ -14,8 +13,10 @@
1413
import torch
1514
import torch.fx
1615
from executorch.backends.arm.common.debug import get_node_debug_info
16+
from executorch.backends.arm.common.type import ensure_type
1717
from executorch.exir import ExportedProgram
1818
from executorch.exir.dialects._ops import ops as exir_ops
19+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1920

2021
from torch._export.utils import (
2122
get_buffer,
@@ -82,17 +83,18 @@ def get_param_tensor(
8283
elif is_lifted_tensor_constant(exp_prog, node):
8384
return get_lifted_tensor_constant(exp_prog, node)
8485
elif is_get_attr_node(node):
86+
target_node = ensure_type(str, node.target)
8587
# This is a hack to support both lifted and unlifted graph
8688
try:
87-
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
89+
return getattr(node.graph.owning_module, target_node)
8890
except AttributeError:
89-
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
91+
return getattr(exp_prog.graph_module, target_node)
9092
raise RuntimeError(f"unsupported param type, {node.op}.")
9193

9294

9395
def create_node(
9496
graph: torch.fx.Graph,
95-
op_target: OpOverload,
97+
op_target: OpOverload | EdgeOpOverload,
9698
args: tuple = (),
9799
kwargs: Optional[dict] = None,
98100
quantize: bool = False,

backends/arm/_passes/cast_int64_pass.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import logging
98
from typing import Set, Type

backends/arm/_passes/convert_expand_copy_to_repeat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import logging
98
from typing import cast, Set, Type

backends/arm/_passes/convert_int64_const_ops_to_int32.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
7-
86

97
import logging
108
from typing import Set, Type

0 commit comments

Comments
 (0)