Commit 9f35045

Merge branch 'pytorch:main' into main

2 parents 7283e4c + 8673567

76 files changed (+1901 / −604 lines)


.ci/docker/build.sh

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ case "${IMAGE_NAME}" in
     LINTRUNNER=""
     CLANG_VERSION=12
     # From https://developer.android.com/ndk/downloads
-    ANDROID_NDK_VERSION=r26c
+    ANDROID_NDK_VERSION=r27b
     ;;
   *)
     echo "Invalid image name ${IMAGE_NAME}"

.github/workflows/android-release-artifacts.yml

Lines changed: 16 additions & 0 deletions

@@ -13,8 +13,24 @@ concurrency:
   cancel-in-progress: true

 jobs:
+  check-if-aar-exists:
+    name: check-if-aar-exists
+    runs-on: ubuntu-22.04
+    timeout-minutes: 10
+    steps:
+      - name: Check if this RC version is already in S3
+        shell: bash
+        run: |
+          VERSION="${{ inputs.version }}"
+          if curl -I "https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar" | grep "200 OK"; then
+            echo "AAR already exists at https://ossci-android.s3.amazonaws.com/executorch/release/${VERSION}/executorch.aar"
+            echo "Will skip build/upload"
+            exit 1
+          fi
+
   build-aar:
     name: build-aar
+    needs: check-if-aar-exists
     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:
       runner: linux.2xlarge
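
The new check-if-aar-exists job fails fast when the requested release AAR is already published, so the dependent build-aar job never overwrites an existing artifact. The same existence check can be reproduced outside CI; a minimal Python sketch, where the version string is a hypothetical placeholder for the workflow's inputs.version:

    # Local sketch of the workflow's HEAD-request existence check.
    import urllib.error
    import urllib.request

    version = "0.4.0"  # hypothetical placeholder; CI reads inputs.version
    url = (
        "https://ossci-android.s3.amazonaws.com/executorch/release/"
        f"{version}/executorch.aar"
    )

    request = urllib.request.Request(url, method="HEAD")
    try:
        with urllib.request.urlopen(request) as response:
            # A 2xx response means the AAR is already published.
            print(f"AAR exists (HTTP {response.status}); build would be skipped")
    except urllib.error.HTTPError as err:
        # S3 answers 403/404 for a missing key; the build would proceed.
        print(f"AAR not found (HTTP {err.code}); build/upload would run")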

.github/workflows/android.yml

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ jobs:
     # NB: Use metal install for KVM support to run the emulator faster
     runs-on: linux.24xl.spr-metal
     env:
-      ANDROID_NDK_VERSION: r26c
+      ANDROID_NDK_VERSION: r27b
       API_LEVEL: 34
     steps:
       - name: Setup SSH (Click me for login details)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,9 @@
     ConvertSplitToSlicePass,
 )
 from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
+from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
+    InsertSqueezeAfterSumPass,
+)
 from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
     ConvertMeanDimToAveragePool,
 )
@@ -49,6 +52,7 @@ def transform_to_backend_pipeline(
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(ConvertMeanDimToAveragePool())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(InsertSqueezeAfterSumPass())
         self.add_pass(ConvertSplitToSlicePass())
         for spec in compile_spec:
             if spec.key == "permute_memory_format":
backends/arm/_passes/insert_squeeze_after_sum_pass.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+# Copyright 2024 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast
+
+import torch
+import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair
+
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class InsertSqueezeAfterSumPass(ExportPass):
+    """
+    In Pytorch, the default behaviour of Tensor.sum is to squeeze
+    the dimension that is summed (keep_dim = False).
+    However, in TOSA, REDUCE_SUM always preserves the
+    rank of the input (keep_dim = True).
+    To get a 1-1 mapping in the sum lowering, normalize the
+    keep_dim = False case to keep_dim = True and add squeeze ops.
+
+    Original:
+        sum(dims, keep_dim = False)
+    After pass:
+        sum(dims, keep_dim = True)
+        (q)
+        (dq)
+        squeeze(dim = dims)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+            if node.target != exir_ops.edge.aten.sum.dim_IntList:
+                continue
+            sum_node = cast(torch.fx.Node, node)
+            keep_dim = cast(bool, sum_node.args[2] if len(sum_node.args) > 2 else False)
+            if keep_dim:
+                continue
+
+            dim_list = cast(list[int], sum_node.args[1])
+            quantized = is_quant_node(sum_node)
+            if quantized:
+                qparams = get_quant_node_args(sum_node.all_input_nodes[0])
+                qparams = qparams + (torch.int8,)
+            else:
+                qparams = None
+
+            # Add keep_dim = True arg to sum node.
+            sum_node.args = sum_node.args[0:2] + (True,)
+
+            with graph_module.graph.inserting_after(sum_node):
+                squeeze_node = create_node(
+                    graph_module.graph, exir_ops.edge.aten.squeeze_copy.dims, ()
+                )
+                sum_node.replace_all_uses_with(squeeze_node)
+                squeeze_node.args = (sum_node, dim_list)
+                if quantized:
+                    sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams)
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
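
The rewrite is behaviour-preserving because a keep_dim = False sum equals a keep_dim = True sum followed by a squeeze of the same dims. A quick eager-mode illustration (shapes chosen arbitrarily, not part of the commit):

    # Eager-mode check of the equivalence the pass exploits.
    import torch

    x = torch.randn(2, 3, 4)
    dims = (1, 2)

    a = x.sum(dim=dims, keepdim=False)               # shape (2,)
    b = x.sum(dim=dims, keepdim=True).squeeze(dims)  # (2, 1, 1) -> (2,)

    assert a.shape == b.shape == (2,)
    assert torch.allclose(a, b)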

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
+            exir_ops.edge.aten.sum.dim_IntList,
             exir_ops.edge.aten.view_copy.default,
             exir_ops.edge.aten.clone.default,
             exir_ops.edge.aten.mean.dim,
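
With this entry, the Arm partitioner now claims aten.sum.dim_IntList nodes for delegation. A hedged sketch of the allow-list membership test this one-line change feeds into (abbreviated, not the partitioner's actual code):

    # Sketch: a node is supported when its call target is in the allow-list.
    import torch.fx

    from executorch.exir.dialects._ops import ops as exir_ops

    SUPPORTED_TARGETS = {
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.sum.dim_IntList,  # added by this commit
        exir_ops.edge.aten.mean.dim,
    }

    def is_supported(node: torch.fx.Node) -> bool:
        return node.op == "call_function" and node.target in SUPPORTED_TARGETS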

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@
     op_softmax,
     op_squeeze,
     op_sub,
+    op_sum,
     op_unsqueeze,
     op_view,
 )

backends/arm/operators/op_sum.py

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import cast, List
+
+import executorch.backends.arm.tosa_quant_utils as tqutils
+import executorch.backends.arm.tosa_utils as tutils
+
+import serializer.tosa_serializer as ts
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.tosa_mapping import TosaArg
+from serializer.tosa_serializer import TosaOp
+from torch.fx import Node
+
+
+@register_node_visitor
+class AddVisitor(NodeVisitor):
+    target = "aten.sum.dim_IntList"
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: Node,
+        tosa_graph: ts.TosaSerializer,
+        inputs: List[TosaArg],
+        output: TosaArg,
+        is_quant_node: bool,
+    ) -> None:
+        input_node = inputs[0]
+        input_shape = list(input_node.shape)
+        dim_list = cast(list[int], inputs[1].special)
+        dim_list = [dim % len(input_node.shape) for dim in dim_list]
+        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
+        assert keep_dim, "This case should be handled by InsertSqueezeAfterSumPass"
+
+        if is_quant_node:
+
+            # Rescale input to 32 bit
+            rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
+                [node.all_input_nodes[0]], tosa_graph
+            )
+
+            prev_node = rescaled_inputs[0]
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1.
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                next_node = tosa_graph.addIntermediate(
+                    tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                    dtype=ts.DType.INT32,
+                )
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [prev_node.name], [next_node.name], attr
+                )
+
+                prev_node = next_node
+            tqutils.rescale_node_back_to_int8(node, prev_node, scale, tosa_graph)
+        else:
+            input_name = input_node.name
+            reduced_shape = input_shape
+
+            # Reduce all dims in dim_list one-by-one.
+            for dim in dim_list:
+                # When reduced, the size of the dim becomes 1
+                reduced_shape[dim] = 1
+
+                attr = ts.TosaSerializerAttribute()
+                attr.AxisAttribute(input_node.dim_order.index(dim))
+
+                if dim == dim_list[-1]:
+                    output_name = output.name
+                else:
+                    output_name = tosa_graph.addIntermediate(
+                        tutils.tosa_shape(reduced_shape, input_node.dim_order),
+                        dtype=ts.DType.FP32,
+                    ).name
+
+                tosa_graph.addOperator(
+                    TosaOp.Op().REDUCE_SUM, [input_name], [output_name], attr
+                )
+
+                input_name = output_name
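
TOSA's REDUCE_SUM reduces a single axis, so the visitor emits one operator per entry in dim_list, threading each intermediate result into the next (the quantized path additionally rescales to INT32 before accumulating and back to INT8 afterwards). The decomposition can be sanity-checked in eager PyTorch:

    # Sanity check: a multi-dim keepdim sum equals a chain of single-dim sums.
    import torch

    x = torch.randn(2, 3, 4, 5)
    dim_list = [1, 3]

    expected = x.sum(dim=dim_list, keepdim=True)  # shape (2, 1, 4, 1)

    out = x
    for dim in dim_list:
        # Each step reduces one axis; its size becomes 1, mirroring the
        # reduced_shape[dim] = 1 bookkeeping in define_node above.
        out = out.sum(dim=dim, keepdim=True)

    assert out.shape == (2, 1, 4, 1)
    assert torch.allclose(out, expected)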

backends/arm/quantizer/arm_quantizer.py

Lines changed: 1 addition & 1 deletion

@@ -267,11 +267,11 @@ class ArmQuantizer(Quantizer):
         "add",
         "sub",
         "mul",
-        "sigmoid",
         "mm",
         "cat",
         "one_to_one",
         "generic",
+        "sum",
     ]

     def __init__(self) -> None:

backends/arm/quantizer/quantization_annotation/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -59,6 +59,6 @@ def decorator(annotator: AnnotatorType):
     mm_annotator,
     mul_annotator,
     one_to_one_annotator,
-    sigmoid_annotator,
     sub_annotator,
+    sum_annotator,
 )
