Commit 0836968

Author: jorgep31415
Update on "[ET-VK] Only save_cache the first time"
Add a check in `save_cache` to return early if the cache file already exists. Currently we append the same cache data to that file, which makes no difference to model-load time.

Differential Revision: [D66179919](https://our.internmc.facebook.com/intern/diff/D66179919/)

[ghstack-poisoned]
2 parents 8f344f9 + 7314968 commit 0836968
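
The `save_cache` change itself is not among the hunks shown below (this merge commit pulls in changes from its parents); as a rough illustration of the early-return check described in the message, here is a minimal C++ sketch with a hypothetical `save_cache` signature and file layout, not the actual ET-VK implementation:

```cpp
#include <filesystem>
#include <fstream>
#include <string>
#include <vector>

// Hypothetical helper illustrating the check described in the commit message;
// the real ET-VK function name, signature, and storage format may differ.
void save_cache(const std::string& cache_path, const std::vector<char>& cache_data) {
  // If a cache file already exists, skip the write entirely. Writing the same
  // cache data again only grows the file without improving model-load time.
  if (std::filesystem::exists(cache_path)) {
    return;
  }
  std::ofstream out(cache_path, std::ios::binary);
  out.write(cache_data.data(), static_cast<std::streamsize>(cache_data.size()));
}
```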

File tree

18 files changed: +336 -49 lines changed

.github/workflows/android-perf.yml

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ jobs:
       fail-fast: false
     with:
       runner: linux.4xlarge
-      docker-image: executorch-ubuntu-22.04-clang12-android
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
       submodules: 'true'
       timeout: 60
       upload-artifact: android-models

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -302,7 +302,7 @@ jobs:
       fail-fast: false
     with:
       runner: linux.2xlarge
-      docker-image: executorch-ubuntu-22.04-clang12-android
+      docker-image: executorch-ubuntu-22.04-qnn-sdk
       submodules: 'true'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       timeout: 900

.gitmodules

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
 [submodule "backends/arm/third-party/ethos-u-core-driver"]
     path = backends/arm/third-party/ethos-u-core-driver
-    url = https://review.mlplatform.org/ml/ethos-u/ethos-u-core-driver
+    url = https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/
 [submodule "backends/arm/third-party/serialization_lib"]
     path = backends/arm/third-party/serialization_lib
-    url = https://review.mlplatform.org/tosa/serialization_lib
+    url = https://git.mlplatform.org/tosa/serialization_lib.git/
 [submodule "backends/vulkan/third-party/Vulkan-Headers"]
     path = backends/vulkan/third-party/Vulkan-Headers
     url = https://github.com/KhronosGroup/Vulkan-Headers

backends/cadence/aot/TARGETS

Lines changed: 17 additions & 0 deletions
@@ -38,6 +38,7 @@ python_library(
     deps = [
         ":passes",
         ":utils",
+        ":ops_registrations",
         "//caffe2:torch",
         "//executorch/backends/cadence/aot/quantizer:fusion_pass",
         "//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,6 +72,8 @@ python_library(
     ],
     deps = [
         ":utils",
+        ":fuse_ops",
+        ":simplify_ops",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
@@ -163,6 +166,20 @@ python_library(
     ],
 )
 
+python_library(
+    name = "simplify_ops",
+    srcs = [
+        "simplify_ops.py",
+    ],
+    typing = True,
+    deps = [
+        ":pass_utils",
+        "//executorch/backends/cadence/aot:pass_utils",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 python_unittest(
     name = "test_graph_builder",
     srcs = [

backends/cadence/aot/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import Callable, cast, Optional
 
+import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 
 from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         return PassResult(graph_module, True)
 
 
-class FuseOpsInGraph:
+class CadenceFuseOpsInGraph:
     passes = [
         FuseMMWithAdd,
         FuseBatchNormWithConv,

backends/cadence/aot/passes.py

Lines changed: 15 additions & 0 deletions
@@ -11,11 +11,13 @@
 import torch
 import torch.fx
 import torch.utils._pytree as pytree
+from executorch.backends.cadence.aot.fuse_ops import CadenceFuseOpsInGraph
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     create_cadence_pass_filter,
     register_cadence_pass,
 )
+from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -346,10 +348,23 @@ def get_passes_in_default_order() -> List[Type[PassType]]:
         ReplaceScalarTensorWithFullPass,
         RemoveCloneOpsTransformImported,
         RemoveNopExpandOpPass,
+        CadenceFuseOpsInGraph.passes,
         ReplaceSqueezeAndUnsqueezeWithViewPass,
         ReplacePT2QuantWithCadenceQuantPass,
         ReplacePT2DequantWithCadenceDequantPass,
+        CadenceSimplifyOpsInGraph.passes,
         # TODO: add the rest of the passes here.
+        # InitializePipeline,
+        # RemoveRedundantOps.passes,
+        # ReorderOpsInGraph.passes,
+        # RemoveJarvisNops.passes,
+        # CadenceFuseOpsInGraph.passes,
+        # ReplaceOpsInGraph.passes,
+        # SimplifyOpsInGraph.passes,
+        # FinalizePipeline,
+        # FuseFullThenReshapePass,
+        # FuseTransposeOpPairsPass,
+        # RemoveNopSliceOrViewOpPass,
     ]
     return pytree.tree_flatten(passes)[0]

backends/cadence/aot/simplify_ops.py

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+# pyre-unsafe
+
+
+# This file contains all the functions that simplify args of an op
+
+import sys
+from typing import Optional
+
+from executorch.backends.cadence.aot.pass_utils import (
+    CadencePassAttribute,
+    register_cadence_pass,
+)
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, ProxyValue
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class SimplifySliceOpPass(ExportPass):
+    """
+    Simplify the start and end indices of slice and slice_scatter ops.
+    """
+
+    def adjust_slice_range(
+        self,
+        length: int,
+        start: Optional[int] = None,
+        end: Optional[int] = None,
+        step: int = 1,
+    ) -> tuple[int, int]:
+        # Get the start index and end index
+        start_val = start if start is not None else 0
+        end_val = end if end is not None else sys.maxsize  # 2^63 - 1
+
+        # If start_val and end_val are negative, add length to them
+        if start_val < 0:
+            start_val += length
+        if end_val < 0:
+            end_val += length
+
+        # If the start val is still outside the tensor_size along the sliced
+        # dimension, adjust it accordingly.
+        if start_val < 0:
+            start_val = 0
+        elif start_val >= length:
+            start_val = length
+
+        # If the end val is still outside the tensor_size along the sliced
+        # dimension, adjust it accordingly.
+        if end_val < start_val:
+            end_val = start_val
+        elif end_val >= length:
+            end_val = length
+
+        # Return the adjusted start and end indices
+        return (start_val, end_val)
+
+    def call_operator(self, op, args, kwargs, meta):
+        # We are only interested in slice_copy or slice_scatter ops
+        if op not in {
+            exir_ops.edge.aten.slice_copy.Tensor,
+            exir_ops.edge.aten.slice_scatter.default,
+        }:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Check if it is a slice_scatter op or not. The slice_scatter op has
+        # an extra src argument at index 1.
+        slice_scatter = op == exir_ops.edge.aten.slice_scatter.default
+        # Parse the arguments
+        # Extract the tensor to be sliced, and the slicing dimension
+        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0
+        # Make dim non-negative
+        dim = dim if dim >= 0 else dim + in_tensor.dim()
+        length = in_tensor.size(dim)
+
+        # Get the adjusted start and end indices
+        start_val = args[2 + slice_scatter] if len(args) > 2 + slice_scatter else None
+        end_val = args[3 + slice_scatter] if len(args) > 3 + slice_scatter else None
+        step = args[4 + slice_scatter] if len(args) > 4 + slice_scatter else 1
+        (start_val, end_val) = self.adjust_slice_range(length, start_val, end_val, step)
+
+        # If the start_val is geq end_val, then we can return an empty tensor
+        # for slice op, or input for slice_scatter op.
+        if start_val >= end_val and slice_scatter:
+            return args[0]
+        if start_val >= end_val:
+            empty_shape = [x for x in in_tensor.shape if x != 0]
+            empty_shape[dim] = 0
+            return super().call_operator(
+                exir_ops.edge.aten.full.default,
+                (tuple(empty_shape), 0),
+                {"dtype": in_tensor.dtype},
+                meta,
+            )
+
+        # Create new args
+        new_args = (
+            (args[0],)
+            + ((args[1],) if slice_scatter else ())
+            + (dim, start_val, end_val, step)
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+
+# This class encapsulates all the functions that simplify the op's args
+class CadenceSimplifyOpsInGraph:
+    passes = [
+        SimplifySliceOpPass,
+    ]

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 81 additions & 0 deletions
@@ -14,13 +14,94 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from torch.ao.quantization.observer import MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
 )
 from torch.fx import Node
 
 
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+    """
+    This function is specific for matmul op 16a8w.
+    """
+
+    def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        input_act1 = node.args[1]
+        input_spec1 = quantization_config.weight
+        input_qspec_map[input_act1] = input_spec1
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_cat(node: Node, quantization_config: QuantizationConfig):
+        input_nodes = node.args[0]
+
+        first_input_node = input_nodes[0]
+        input_qspec_map = {}
+        input_qspec_map[first_input_node] = quantization_config.input_activation
+        share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
+            (first_input_node, node)
+        )
+
+        for input_node in input_nodes[1:]:
+            if input_node not in input_qspec_map:
+                input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=share_qparams_with_input_act0_qspec,
+            _annotated=True,
+        )
+
+    def annotate_single_in_single_out(
+        node: Node, quantization_config: QuantizationConfig
+    ) -> None:
+
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_qspec_map[input_act] = quantization_config.input_activation
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    def annotate_matmul_input1(node: Node):
+        quantization_config_8a8w = get_default_8bit_qnn_ptq_config(
+            act_symmetric=True, act_observer=MinMaxObserver
+        )
+        while isinstance(node, Node) and node.op == "call_function":
+            if node.target in [
+                torch.ops.aten.permute.default,
+                torch.ops.aten.transpose.int,
+            ]:
+                annotate_single_in_single_out(node, quantization_config_8a8w)
+                node = node.args[0]
+            elif node.target == torch.ops.aten.cat.default:
+                annotate_cat(node, quantization_config_8a8w)
+                node = node.args[0][0]
+            else:
+                node = node.args[0]
+
+    quantization_config_16a8w = get_16a8w_qnn_ptq_config(act_observer=MinMaxObserver)
+
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
+            annotate_matmul(node, quantization_config_16a8w)
+            annotate_matmul_input1(node.args[1])
+
+
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for llama matmul op 16a8w.

devtools/bundled_program/schema/scalar_type.fbs

Lines changed: 7 additions & 1 deletion
@@ -18,13 +18,19 @@ enum ScalarType : byte {
   FLOAT = 6,
   DOUBLE = 7,
   BOOL = 11,
-  // TODO(jakeszwe): Verify these are unused and then remove support
   QINT8 = 12,
   QUINT8 = 13,
   QINT32 = 14,
   QUINT4X2 = 16,
   QUINT2X4 = 17,
   BITS16 = 22,
+  FLOAT8E5M2 = 23,
+  FLOAT8E4M3FN = 24,
+  FLOAT8E5M2FNUZ = 25,
+  FLOAT8E4M3FNUZ = 26,
+  UINT16 = 27,
+  UINT32 = 28,
+  UINT64 = 29,
   // Types currently not implemented.
   // COMPLEXHALF = 8,
   // COMPLEXFLOAT = 9,
