Commit 6d0a0be

Update base for Update on "[Executorch] mul broadcast update"

Handle broadcasting for >2D tensors in the optimized library. For now, the optimized path supports broadcasting only across dims other than the 0th and the (N-1)st.

Differential Revision: [D64156862](https://our.internmc.facebook.com/intern/diff/D64156862/)

[ghstack-poisoned]
2 parents 6a5171e + 9890c24 commit 6d0a0be
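
As a quick illustration of the case the commit message describes (my sketch, based on standard PyTorch broadcasting semantics; not code from this commit): an element-wise mul of >2D tensors where the size-1 dim sits in the middle, i.e. neither the 0th nor the (N-1)st dim.

import torch

# Hypothetical example of the newly supported case: broadcasting across a
# middle dimension (dim 1 here), while dim 0 and the last dim match exactly.
a = torch.randn(4, 8, 16)
b = torch.randn(4, 1, 16)
out = a * b  # b's dim 1 broadcasts from 1 to 8
assert out.shape == (4, 8, 16)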

27 files changed: +683 −470 lines


.github/workflows/update-viablestrict.yml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ jobs:
       with:
         repository: pytorch/executorch
         stable-branch: viable/strict
-        requires: '[\"pull\", \"lint\", \"trunk\", \"Build documentation\", \"^Android$\", \"^Apple$\"]'
+        requires: '[\"pull\", \"lint\", \"trunk\", \"Build documentation\", \"^Apple$\"]'
         secret-bot-token: ${{ secrets.UPDATEBOT_TOKEN }}
         clickhouse-url: ${{ secrets.CLICKHOUSE_URL }}
         clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }}

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 33 additions & 0 deletions
@@ -12,6 +12,7 @@
     QuantizationConfig,
 )
 from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY
+from executorch.exir.dialects._ops import ops as exir_ops
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
     SharedQuantizationSpec,
@@ -144,3 +145,35 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
+
+
+def get_custom_quant_ios_dtype(
+    cache_shape: torch.Size,
+    node: torch.fx.Node,
+    kv_dtype=torch.uint8,
+    sharding_dtype=torch.uint16,
+):
+    """
+    This function is specific for llama inputs and outputs
+    """
+    if node.op == "placeholder" and "attention_sdpa_kv_cache_past_" in node.name:
+        return kv_dtype
+
+    # Tag index put node before copy node, because copy is a skipped node in qnn
+    if (
+        exir_ops.edge.aten.index_put.default == node.target
+        and node.meta["val"].shape == cache_shape
+    ):
+        return kv_dtype
+
+    # Tag sharding io
+    if exir_ops.edge.llama.fallback.default in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype
+
+    # Tag index op as quantized tensors. It is caused by sharding
+    if exir_ops.edge.aten.index.Tensor in [
+        u.target for u in list(node.users.keys())
+    ] + [node.target]:
+        return sharding_dtype

backends/qualcomm/utils/utils.py

Lines changed: 10 additions & 0 deletions
@@ -71,6 +71,7 @@
     QCOM_PASS_EXPAND_BROADCAST_SHAPE,
     QCOM_PASS_SKIP_ADVANCED_REQUANT,
     QCOM_QNN_COMPILE_SPEC,
+    QCOM_QUANTIZED_IO,
 )

 from executorch.exir import ExirExportedProgram
@@ -876,3 +877,12 @@ def get_soc_to_chipset_map():
         "SM8475": QcomChipset.SM8475,
         "SM8450": QcomChipset.SM8450,
     }
+
+
+def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
+    """
+    Tag io nodes which get/output quantized tensor. No need to insert q/dq in qnn_preprocess
+    """
+    for node in gm.graph.nodes:
+        if dtype := get_quant_io_dtype_fn(node):
+            node.meta[QCOM_QUANTIZED_IO] = dtype
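
Taken together, the two helpers added in this commit suggest the following usage; this is a sketch under my own assumptions (the wrapper name tag_llama_quant_io, the functools.partial binding, and the example cache shape are mine, not from the diff):

from functools import partial

import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    get_custom_quant_ios_dtype,
)
from executorch.backends.qualcomm.utils.utils import tag_quant_io


def tag_llama_quant_io(gm: torch.fx.GraphModule, cache_shape: torch.Size) -> None:
    # Bind cache_shape so the callback matches tag_quant_io's expected
    # signature of node -> Optional[torch.dtype].
    tag_quant_io(gm, partial(get_custom_quant_ios_dtype, cache_shape))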

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 2 additions & 3 deletions
@@ -6,10 +6,9 @@

 import sys

+import executorch.backends.vulkan.custom_ops_lib # noqa
+
 import torch
-from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
-    conv_with_clamp_op,
-)

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

backends/transforms/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def define_common_targets():
         deps = [
             ":utils",
             "//caffe2:torch",
-            "//executorch/backends/vulkan/_passes:custom_ops_defs",
+            "//executorch/backends/vulkan:custom_ops_lib",
             "//executorch/exir:pass_base",
             "//executorch/exir:sym_util",
             "//executorch/exir/dialects:lib",

backends/vulkan/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -83,14 +83,14 @@ set(vulkan_standard_shaders_cpp ${generated_spv_cpp})
 set(SCHEMA_INCLUDE_DIR ${CMAKE_BINARY_DIR}/schema/include)

 set(GENERATED_HEADER
-    ${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/schema_generated.h
+    ${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/schema_generated.h
 )

 add_custom_command(
   OUTPUT ${GENERATED_HEADER}
   COMMAND
     ${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o
-    "${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/" ${_vulkan_schema__srcs}
+    "${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/" ${_vulkan_schema__srcs}
   WORKING_DIRECTORY ${EXECUTORCH_ROOT}
   COMMENT "Generating vulkan_schema headers"
   VERBATIM

backends/vulkan/TARGETS

Lines changed: 0 additions & 33 deletions
@@ -1,37 +1,4 @@
-load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(":targets.bzl", "define_common_targets")
-
 oncall("executorch")

 define_common_targets(is_fbcode = True)
-
-runtime.python_library(
-    name = "vulkan_preprocess",
-    srcs = [
-        "serialization/vulkan_graph_builder.py",
-        "serialization/vulkan_graph_schema.py",
-        "serialization/vulkan_graph_serialize.py",
-        "vulkan_preprocess.py",
-    ],
-    resources = [
-        "serialization/schema.fbs",
-    ],
-    visibility = [
-        "//executorch/...",
-        "//executorch/vulkan/...",
-        "@EXECUTORCH_CLIENTS",
-    ],
-    deps = [
-        "//executorch/backends/transforms:addmm_mm_to_linear",
-        "//executorch/backends/transforms:fuse_batch_norm_with_conv",
-        "//executorch/backends/transforms:fuse_conv_with_clamp",
-        "//executorch/backends/transforms:fuse_dequant_linear",
-        "//executorch/backends/transforms:fuse_view_copy",
-        "//executorch/backends/transforms:remove_clone_ops",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
-        "//executorch/exir:graph_module",
-        "//executorch/exir/_serialize:_bindings",
-        "//executorch/exir/_serialize:lib",
-        "//executorch/exir/backend:backend_details",
-    ],
-)

backends/vulkan/_passes/TARGETS

Lines changed: 1 addition & 26 deletions
@@ -3,31 +3,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

 oncall("executorch")

-runtime.python_library(
-    name = "custom_ops_defs",
-    srcs = [
-        "custom_ops_defs.py",
-    ],
-    visibility = [
-        "//executorch/...",
-        "@EXECUTORCH_CLIENTS",
-    ],
-    deps = [
-        "//caffe2:torch",
-    ],
-)
-
-python_unittest(
-    name = "test_custom_ops",
-    srcs = [
-        "test_custom_ops.py",
-    ],
-    deps = [
-        ":custom_ops_defs",
-        "//caffe2:torch",
-    ],
-)
-
 runtime.python_library(
     name = "insert_prepack_nodes",
     srcs = ["insert_prepack_nodes.py"],
@@ -62,7 +37,7 @@ runtime.python_library(
         "//executorch/backends/...",
     ],
     deps = [
-        ":custom_ops_defs",
+        "//executorch/backends/vulkan:custom_ops_lib",
         "//pytorch/ao:torchao",
     ]
 )

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 14 additions & 27 deletions
@@ -6,39 +6,27 @@

 # pyre-strict

-from typing import List
-
-import executorch.backends.vulkan._passes.custom_ops_defs # noqa
+import executorch.backends.vulkan.custom_ops_lib # noqa

 import torch

+from executorch.backends.vulkan.op_registry import handles_own_prepacking
+
 from executorch.exir.dialects._ops import ops as exir_ops

 from torch._export.utils import is_buffer, is_param
 from torch.export import ExportedProgram

-USES_WEIGHTS: List[torch._ops.OpOverload] = [
-    exir_ops.edge.aten.embedding.default,
-    exir_ops.edge.aten.convolution.default,
-    exir_ops.edge.et_vk.conv_with_clamp.default,
-    exir_ops.edge.aten.linear.default,
-    exir_ops.edge.aten._weight_int8pack_mm.default,
-    exir_ops.edge.et_vk.linear_weight_int4.default,
-    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-    exir_ops.edge.aten.native_layer_norm.default,
-    "llama::sdpa_with_kv_cache",
-]
-

 def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
     """
     Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
     is responsible for transferring the tensor data, which is serialized with the model,
     to a GPU tensor object during the prepacking stage of model execution.

-    Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
-    prefer to handle prepacking within the operator. For these ops, the constant tensor
-    data will be passed directly as an argument into the operator implementation.
+    Some operators are performance sensitive and will prefer to handle prepacking within
+    the operator. For these ops, the constant tensor data will be passed directly as an
+    argument into the operator implementation.
     """

     def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -58,22 +46,21 @@ def is_param_node(node: torch.fx.Node) -> bool:
             or is_constant(node)
         )

-    def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
+    def prepack_not_required(node: torch.fx.Node) -> bool:
         if not is_param_node(node):
-            return False
+            return True

         for user in node.users:
-            if user.op == "call_function" and (
-                # pyre-ignore [16]
-                user.target in USES_WEIGHTS
-                or user.target.name() in USES_WEIGHTS
+            if user.op == "call_function" and handles_own_prepacking(
+                # pyre-ignore
+                user.target
             ):
-                return False
+                return True

-        return True
+        return False

     for node in program.graph_module.graph.nodes:
-        if not is_non_weight_param_tensor(node):
+        if prepack_not_required(node):
             continue

         with program.graph_module.graph.inserting_after(node):
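
The refactor above swaps the hardcoded USES_WEIGHTS list for a handles_own_prepacking predicate imported from executorch.backends.vulkan.op_registry. The registry's internals are not part of this diff; purely as an assumed sketch, such a predicate could be backed by a per-op feature table along these lines (all names below except handles_own_prepacking are hypothetical):

from typing import Any, Dict

# Hypothetical feature table keyed by op overload (or op name string, to
# cover entries like "llama::sdpa_with_kv_cache" from the removed list).
_OP_FEATURES: Dict[Any, bool] = {}


def register_op(op: Any, handles_prepacking: bool = False) -> None:
    # Record whether the op's implementation prepacks its own weights.
    _OP_FEATURES[op] = handles_prepacking


def handles_own_prepacking(op: Any) -> bool:
    # Ops not registered fall back to graph-level et_vk.prepack nodes.
    return _OP_FEATURES.get(op, False)

The upside of a registry over the old list is that each op declares its own prepacking behavior at registration time, so passes like insert_prepack_nodes no longer need a parallel list to keep in sync.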

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 2 additions & 4 deletions
@@ -1,13 +1,11 @@
 import logging
 from typing import Any, Callable, Dict, Optional, Type

+import executorch.backends.vulkan.custom_ops_lib # noqa
+
 import torch
 import torch.nn.functional as F

-from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
-    linear_weight_int4_op,
-)
-
 from torchao.quantization.GPTQ import _check_linear_int4_k
 from torchao.quantization.unified import Quantizer
 from torchao.quantization.utils import groupwise_affine_quantize_tensor
