
Commit f726c4e

[ET-VK] Migrate off of xnnpack_quantizer_utils (#12998)
# Context

As the vulkan_quantizer file grows, we will eventually need to migrate to a custom utils file and stop depending on xnnpack_quantizer_utils. We migrate only the minimal set of functions needed for vulkan_quantizer to work.

# Changes

We create a new file `vulkan_quantizer_utils.py` and migrate `vulkan_quantizer` off of `xnnpack_quantizer_utils.py`. No modifications are needed to work independently of the xnnpack utils, apart from `bits_to_range`, which avoids having to spell out the quantization ranges every time.

Differential Revision: [D78290055](https://our.internmc.facebook.com/intern/diff/D78290055/)

[ghstack-poisoned]
1 parent 9d8c008 commit f726c4e
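
Note: the diff below does not itself define `bits_to_range`; as a rough, hypothetical sketch (name and signature are assumptions based only on the commit message), a helper like this would derive the integer quantization range from a bit width so callers do not have to spell out quant_min/quant_max each time:

# Hypothetical sketch only -- bits_to_range is mentioned in the commit message
# but not shown in this diff; the signature here is an assumption.
def bits_to_range(bits: int, signed: bool = True) -> tuple[int, int]:
    """Return (quant_min, quant_max) for an integer type of the given bit width."""
    if signed:
        # e.g. bits=8 -> (-128, 127)
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    # e.g. bits=8 -> (0, 255)
    return 0, 2**bits - 1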


3 files changed: +216 -4 lines changed


backends/vulkan/quantizer/TARGETS

Lines changed: 9 additions & 3 deletions
@@ -4,11 +4,17 @@ oncall("executorch")
 
 python_library(
     name = "vulkan_quantizer",
-    srcs = [
-        "vulkan_quantizer.py",
+    srcs = ["vulkan_quantizer.py"],
+    deps = [
+        ":vulkan_quantizer_utils",
+        "//caffe2:torch",
     ],
+)
+
+python_library(
+    name = "vulkan_quantizer_utils",
+    srcs = ["vulkan_quantizer_utils.py"],
     deps = [
         "//caffe2:torch",
-        "//executorch/backends/xnnpack/quantizer:xnnpack_quantizer_utils",
     ],
 )

backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from typing import Callable, Optional
 
 import torch
-from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
+from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
     _convert_scalars_to_attrs,
     OP_TO_ANNOTATOR,
     propagate_annotation,

backends/vulkan/quantizer/vulkan_quantizer_utils.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Callable, Optional
+
+import torch
+from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    QuantizationAnnotation,
+    QuantizationConfig,
+    SharedQuantizationSpec,
+)
+from torchao.quantization.pt2e.utils import get_new_attr_name_with_prefix
+
+__all__ = [
+    "OP_TO_ANNOTATOR",
+    "propagate_annotation",
+    "_convert_scalars_to_attrs",
+]
+
+
+AnnotatorType = Callable[
+    [
+        torch.fx.GraphModule,
+        Optional[QuantizationConfig],
+        Optional[Callable[[Node], bool]],
+    ],
+    Optional[list[list[Node]]],
+]
+OP_TO_ANNOTATOR: dict[str, AnnotatorType] = {}
+
+
+def register_annotator(op: str) -> Callable[[AnnotatorType], None]:
+    def decorator(annotator: AnnotatorType) -> None:
+        OP_TO_ANNOTATOR[op] = annotator
+
+    return decorator
+
+
+def _is_annotated(nodes: list[Node]) -> bool:
+    """
+    Given a list of nodes (that represents an operator pattern),
+    check if any of the node is annotated, return True if any of the node
+    is annotated, otherwise return False
+    """
+    annotated = False
+    for node in nodes:
+        annotated = annotated or (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        )
+    return annotated
+
+
+def _mark_nodes_as_annotated(nodes: list[Node]) -> None:
+    for node in nodes:
+        if node is not None:
+            if "quantization_annotation" not in node.meta:
+                node.meta["quantization_annotation"] = QuantizationAnnotation()
+            node.meta["quantization_annotation"]._annotated = True
+
+
+@register_annotator("linear")
+def _annotate_linear(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[list[list[Node]]]:
+    annotated_partitions = []
+    input_act_qspec = get_input_act_qspec(quantization_config)
+    output_act_qspec = get_output_act_qspec(quantization_config)
+    weight_qspec = get_weight_qspec(quantization_config)
+    bias_qspec = get_bias_qspec(quantization_config)
+    for node in gm.graph.nodes:
+        if node.op != "call_function" or node.target != torch.ops.aten.linear.default:
+            continue
+        if filter_fn and not filter_fn(node):
+            continue
+        act_node = node.args[0]
+        weight_node = node.args[1]
+        bias_node = None
+        if len(node.args) > 2:
+            bias_node = node.args[2]
+
+        if _is_annotated([node]) is False:  # type: ignore[list-item]
+            annotate_input_qspec_map(
+                node,
+                act_node,
+                input_act_qspec,
+            )
+            annotate_input_qspec_map(
+                node,
+                weight_node,
+                weight_qspec,
+            )
+            nodes_to_mark_annotated = [node, weight_node]
+            if bias_node:
+                annotate_input_qspec_map(
+                    node,
+                    bias_node,
+                    bias_qspec,
+                )
+                nodes_to_mark_annotated.append(bias_node)
+            annotate_output_qspec(node, output_act_qspec)
+            _mark_nodes_as_annotated(nodes_to_mark_annotated)
+            annotated_partitions.append(nodes_to_mark_annotated)
+
+    return annotated_partitions
+
+
+def _is_share_obs_or_fq_op(op: Callable[..., torch.Tensor]) -> bool:
+    return op in [
+        torch.ops.aten.relu.default,
+        torch.ops.aten.hardtanh.default,
+        torch.ops.aten.hardtanh_.default,
+        torch.ops.aten.max_pool2d.default,
+        torch.ops.aten.mean.default,
+        torch.ops.aten.mean.dim,
+        torch.ops.aten.permute.default,
+        torch.ops.aten.permute_copy.default,
+        torch.ops.aten.squeeze.dim,
+        torch.ops.aten.squeeze_copy.dim,
+        torch.ops.aten.adaptive_avg_pool2d.default,
+        torch.ops.aten.view_copy.default,
+        torch.ops.aten.view.default,
+        torch.ops.aten.slice_copy.Tensor,
+        torch.ops.aten.flatten.using_ints,
+    ]
+
+
+def propagate_annotation(model: torch.fx.GraphModule) -> None:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or not _is_share_obs_or_fq_op(n.target):
+            continue
+
+        prev_node = n.args[0]
+        if not isinstance(prev_node, Node):
+            continue
+
+        quantization_annotation = prev_node.meta.get("quantization_annotation", None)
+        if not quantization_annotation:
+            continue
+
+        output_qspec = quantization_annotation.output_qspec
+        if not output_qspec:
+            continue
+
+        # make sure current node is not annotated
+        if (
+            "quantization_annotation" in n.meta
+            and n.meta["quantization_annotation"]._annotated
+        ):
+            continue
+
+        shared_qspec = SharedQuantizationSpec(prev_node)
+        # propagate the previous output_qspec to the current node
+        n.meta["quantization_annotation"] = QuantizationAnnotation(
+            input_qspec_map={
+                prev_node: shared_qspec,
+            },
+            output_qspec=shared_qspec,
+            _annotated=True,
+        )
+
+
+def _convert_scalars_to_attrs(model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    for n in model.graph.nodes:
+        if n.op != "call_function" or n.target not in [
+            torch.ops.aten.add.Tensor,
+            torch.ops.aten.mul.Tensor,
+        ]:
+            continue
+        args = list(n.args)
+        new_args = []
+        for i in range(len(args)):
+            if isinstance(args[i], torch.fx.Node):
+                new_args.append(args[i])
+                continue
+            prefix = "_tensor_constant_"
+            get_new_attr_name = get_new_attr_name_with_prefix(prefix)
+            tensor_constant_name = get_new_attr_name(model)
+            float_tensor = torch.tensor(float(args[i]))
+            model.register_buffer(tensor_constant_name, float_tensor)
+            fake_mode = n.meta["val"].fake_mode
+            with model.graph.inserting_before(n):
+                get_attr_node = model.graph.create_node(
+                    "get_attr", tensor_constant_name, (), {}
+                )
+                get_attr_node.meta["val"] = fake_mode.from_tensor(
+                    float_tensor, static_shapes=True
+                )
+            new_args.append(get_attr_node)
+        n.args = tuple(new_args)
+    model.recompile()
+    return model
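
For orientation, here is a minimal usage sketch (not part of this commit) of how a quantizer built on these utilities might annotate a captured graph; the wrapper function name and the way the QuantizationConfig is obtained are assumptions for illustration:

# Minimal usage sketch (not from this commit): drive annotation with the
# migrated utilities. The wrapper name and config handling are assumed.
import torch
from torchao.quantization.pt2e.quantizer import QuantizationConfig

from executorch.backends.vulkan.quantizer.vulkan_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
)


def annotate_for_vulkan(
    model: torch.fx.GraphModule, quantization_config: QuantizationConfig
) -> torch.fx.GraphModule:
    # Lift scalar operands of add/mul into tensor attributes so they can carry qspecs.
    model = _convert_scalars_to_attrs(model)
    # Annotate aten.linear nodes with the given config (no filter -> all matching nodes).
    OP_TO_ANNOTATOR["linear"](model, quantization_config, None)
    # Share observers/fake-quants across shape-preserving ops (view, permute, ...).
    propagate_annotation(model)
    return model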
