Skip to content

Commit 27b61b3

Browse files
Author: pytorchbot — committed "2024-11-12 nightly release (dc41596)"
1 parent a453011 commit 27b61b3

File tree

82 files changed

+2724
-256
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

82 files changed

+2724
-256
lines changed

.ci/scripts/test_llama_runner_eager.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ run_and_verify() {
4242
-d fp32 \
4343
--max_seq_length 32 \
4444
--temperature 0 \
45+
--show_tokens \
4546
--prompt "Once upon a time," > result.txt
4647

4748
# Verify result.txt
4849
RESULT=$(cat result.txt)
49-
EXPECTED_RESULT="there was a little girl"
50+
EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
5051
if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
5152
echo "Actual result: ${RESULT}"
5253
echo "Success"

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@
6464
[submodule "third-party/pybind11"]
6565
path = third-party/pybind11
6666
url = https://github.com/pybind/pybind11.git
67+
[submodule "third-party/ao"]
68+
path = third-party/ao
69+
url = https://github.com/pytorch/ao.git

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ We recommend using the latest release tag from the
4343
See [CONTRIBUTING.md](CONTRIBUTING.md) for details about issues, PRs, code
4444
style, CI jobs, and other development topics.
4545

46+
To connect with us and other community members, we invite you to join PyTorch Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
47+
* Head to the `#executorch-general` channel for general questions, discussion, and community support.
48+
* Join the `#executorch-contributors` channel if you're interested in contributing directly to project development.
49+
50+
4651
## Directory Structure
4752

4853
```

backends/arm/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ python_library(
7070
],
7171
)
7272

73+
python_library(
74+
name = "tosa_specification",
75+
srcs = [
76+
"tosa_specification.py",
77+
],
78+
typing = True,
79+
deps = [
80+
"fbsource//third-party/pypi/packaging:packaging",
81+
"//executorch/exir/backend:compile_spec_schema",
82+
],
83+
)
84+
7385
python_library(
7486
name = "tosa_utils",
7587
srcs = [

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_first_fake_tensor,
1515
insert_q_dq_pair,
1616
)
17-
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
17+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
1818
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.pass_base import ExportPass, PassResult
@@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs):
4242
return args[0]
4343

4444

45+
register_passable_op(torch.ops.passthrough_to_tosa._transpose)
46+
47+
4548
class AnnotateChannelsLastDimOrder(ExportPass):
4649
"""
4750
Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order

backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88

99
import torch
1010
import torch.fx
11-
from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
12-
13-
from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
1412
from executorch.exir.dialects._ops import ops as exir_ops
1513
from executorch.exir.pass_base import ExportPass, PassResult
1614

@@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass):
2826
sum(dims, keep_dim = False)
2927
After pass:
3028
sum(dims, keep_dim = True)
31-
(q)
32-
(dq)
3329
squeeze(dim = dims)
3430
"""
3531

@@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule):
4541
continue
4642

4743
dim_list = cast(list[int], sum_node.args[1])
48-
quantized = is_quant_node(sum_node)
49-
if quantized:
50-
qparams = get_quant_node_args(sum_node.all_input_nodes[0])
51-
qparams = qparams + (torch.int8,)
52-
else:
53-
qparams = None
5444

5545
# Add keep_dim = True arg to sum node.
5646
sum_node.args = sum_node.args[0:2] + (True,)
@@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule):
6151
)
6252
sum_node.replace_all_uses_with(squeeze_node)
6353
squeeze_node.args = (sum_node, dim_list)
64-
if quantized:
65-
sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
6654
graph_module.graph.eliminate_dead_code()
6755
graph_module.recompile()
6856
graph_module = super().call(graph_module).graph_module

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import cast, Optional
1010

1111
import torch.fx
12-
from executorch.backends.arm.tosa_quant_utils import is_quant_node
12+
from executorch.backends.arm.tosa_quant_utils import is_node_quantized
1313
from executorch.exir.dialects._ops import ops as exir_ops
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch._ops import OpOverload
@@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule):
113113
slice_node = graph.create_node(
114114
"call_function", self.slice_op, (last_node,) + args
115115
)
116-
if is_quant_node(last_node):
116+
if is_node_quantized(last_node):
117117
q_params = last_node.args[1:]
118118
dq_node = insert_q_dq_pair(
119119
graph_module.graph, slice_node, q_params

backends/arm/arm_backend.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self):
5252
# TODO MLETORCH-265 Remove permute_nhwc flag
5353
self.permute_nhwc = False
5454
self.quantize_io = False
55+
self.tosa_version = None
5556

5657
def ethosu_compile_spec(
5758
self,

backends/arm/operators/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ python_library(
77
typing = True,
88
deps = [
99
"//executorch/backends/arm:tosa_mapping",
10+
"//executorch/backends/arm:tosa_specification",
1011
],
1112
)
1213

backends/arm/operators/op_bmm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
register_node_visitor,
1515
)
1616
from executorch.backends.arm.tosa_mapping import TosaArg
17-
from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args
17+
from executorch.backends.arm.tosa_quant_utils import (
18+
build_rescale,
19+
get_quant_arg_downstream,
20+
get_quant_arg_upstream,
21+
)
1822
from executorch.backends.arm.tosa_utils import get_two_inputs
1923
from serializer.tosa_serializer import TosaOp
2024

@@ -42,8 +46,10 @@ def define_node(
4246
# For INT8, we need to get the zero points and add an intermediate tensor
4347
# for a later rescale.
4448
if is_quant_node:
45-
input0_zp = get_quant_node_args(input0).zp
46-
input1_zp = get_quant_node_args(input1).zp
49+
input0_q_params = get_quant_arg_upstream(input0)
50+
input1_q_params = get_quant_arg_upstream(input1)
51+
input0_zp = input0_q_params.zp
52+
input1_zp = input1_q_params.zp
4753
bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
4854
bmm_output_name = bmm_result.name
4955
else:
@@ -63,9 +69,7 @@ def define_node(
6369

6470
# As INT8 accumulates into INT32, we need to rescale it back to INT8
6571
if is_quant_node:
66-
input0_q_params = get_quant_node_args(input0)
67-
input1_q_params = get_quant_node_args(input1)
68-
output_q_params = get_quant_node_args(list(node.users)[0])
72+
output_q_params = get_quant_arg_downstream(list(node.users)[0])
6973

7074
final_output_scale = (
7175
input0_q_params.scale * input1_q_params.scale

0 commit comments

Comments (0)