Skip to content

Commit 97a724f

Browse files
cccclaifacebook-github-bot
authored andcommitted
Use Q_ANNOTATION_KEY (#12728)
Summary: Pull Request resolved: #12728 pytorch/ao#2525 introduced Q_ANNOTATION_KEY to avoid manually typing "quantization_annotation". Trying to apply it in our codebase Reviewed By: jerryzh168 Differential Revision: D78193037
1 parent ef10a35 commit 97a724f

File tree

14 files changed

+120
-122
lines changed

14 files changed

+120
-122
lines changed

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,21 @@
1818
from torch.fx import GraphModule, Node
1919

2020
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
21+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2122

2223

2324
def is_annotated(node: Node) -> bool:
2425
"""Given a node return whether the node is annotated."""
2526
return (
26-
"quantization_annotation" in node.meta
27-
and cast(
28-
QuantizationAnnotation, node.meta["quantization_annotation"]
29-
)._annotated
27+
Q_ANNOTATION_KEY in node.meta
28+
and cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])._annotated
3029
)
3130

3231

3332
def is_output_annotated(node: Node) -> bool:
3433
"""Given a node, return whether the output of the node is annotated."""
35-
if "quantization_annotation" in node.meta:
36-
annotation = cast(QuantizationAnnotation, node.meta["quantization_annotation"])
34+
if Q_ANNOTATION_KEY in node.meta:
35+
annotation = cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])
3736
return annotation._annotated and annotation.output_qspec is not None
3837
else:
3938
return False
@@ -43,9 +42,9 @@ def mark_node_as_annotated(node: Node) -> None:
4342
"""Marks node as annotated. If needed, an empty QuantizationAnnotation is added
4443
to the quantization_annotation node meta entry.
4544
"""
46-
if "quantization_annotation" not in node.meta:
47-
node.meta["quantization_annotation"] = QuantizationAnnotation()
48-
node.meta["quantization_annotation"]._annotated = True
45+
if Q_ANNOTATION_KEY not in node.meta:
46+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
47+
node.meta[Q_ANNOTATION_KEY]._annotated = True
4948

5049

5150
def is_ok_for_quantization(node: Node, gm: GraphModule):

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
QuantizationSpec,
4343
Quantizer,
4444
)
45+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
4546

4647

4748
act_qspec_asym8s = QuantizationSpec(
@@ -127,7 +128,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
127128

128129
for output, *custom_spec in anchors.output:
129130
# pyre-ignore[16]: no attribute
130-
output.meta["quantization_annotation"] = QuantizationAnnotation(
131+
output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
131132
# pyre-ignore[6]: incompatible parameter type
132133
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
133134
_annotated=True,
@@ -143,7 +144,7 @@ def annotate_inputs(
143144
for node, idx, *custom_spec in inputs:
144145
# pyre-ignore[16]: no attribute
145146
annotation = node.meta.get(
146-
"quantization_annotation",
147+
Q_ANNOTATION_KEY,
147148
QuantizationAnnotation(_annotated=True),
148149
)
149150
arg = (
@@ -157,21 +158,21 @@ def annotate_inputs(
157158
custom_spec[0] if custom_spec else spec
158159
)
159160
# pyre-ignore[16]: no attribute
160-
node.meta["quantization_annotation"] = annotation
161+
node.meta[Q_ANNOTATION_KEY] = annotation
161162

162163
def annotate_weights_or_biases(
163164
weights_or_biases: List[Tuple[fx.Node, int]],
164165
spec: Optional[QuantizationSpec],
165166
) -> None:
166167
for node, idx, *custom_spec in weights_or_biases:
167168
annotation = node.meta.get(
168-
"quantization_annotation",
169+
Q_ANNOTATION_KEY,
169170
QuantizationAnnotation(_annotated=True),
170171
)
171172
annotation.input_qspec_map[node.args[idx]] = (
172173
custom_spec[0] if custom_spec else spec
173174
)
174-
node.meta["quantization_annotation"] = annotation
175+
node.meta[Q_ANNOTATION_KEY] = annotation
175176

176177
# pyre-ignore[6]: incompatible parameter type
177178
annotate_inputs(anchors.inputs, input_act_qspec)

backends/cadence/aot/quantizer/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
SourcePartition,
2222
)
2323
from torchao.quantization.pt2e import ObserverOrFakeQuantize
24+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2425

2526

2627
def quantize_tensor_multiplier(
@@ -88,8 +89,7 @@ def is_annotated(nodes: List[fx.Node]) -> bool:
8889
annotated = False
8990
for node in nodes:
9091
annotated = annotated or (
91-
"quantization_annotation" in node.meta
92-
and node.meta["quantization_annotation"]._annotated
92+
Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated
9393
)
9494
return annotated
9595

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
QuantizationSpec,
2626
Quantizer,
2727
)
28+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2829

2930

3031
@dataclass(eq=True, frozen=True)
@@ -67,18 +68,15 @@ def annotate(self, model: GraphModule):
6768
]:
6869
continue
6970

70-
if (
71-
"quantization_annotation" in node.meta
72-
and node.meta["quantization_annotation"]._annotated
73-
):
71+
if Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated:
7472
continue
7573

7674
input_qspec_map = {
7775
node.args[0]: config.input_activation,
7876
node.args[1]: config.input_activation,
7977
}
8078

81-
node.meta["quantization_annotation"] = QuantizationAnnotation(
79+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
8280
input_qspec_map=input_qspec_map,
8381
output_qspec=config.output_activation,
8482
_annotated=True,

backends/example/example_operators/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
8+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
89

910

1011
def _nodes_are_annotated(node_list):
1112
for node in node_list:
12-
quantization_annotation = node.meta.get("quantization_annotation", None)
13+
quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None)
1314
if not quantization_annotation:
1415
return False
1516
if quantization_annotation._annotated:
@@ -23,11 +24,11 @@ def _annotate_nodes(node_tuples, quant_spec, input_node=False):
2324
for node_tuple in node_tuples:
2425
node = node_tuple[0]
2526
quant_annotation = node.meta.get(
26-
"quantization_annotation", QuantizationAnnotation(_annotated=True)
27+
Q_ANNOTATION_KEY, QuantizationAnnotation(_annotated=True)
2728
)
2829
if input_node:
2930
input_node = node_tuple[1]
3031
quant_annotation.input_qspec_map[input_node] = quant_spec
3132
else:
3233
quant_annotation.output_qspec = quant_spec
33-
node.meta["quantization_annotation"] = quant_annotation
34+
node.meta[Q_ANNOTATION_KEY] = quant_annotation

backends/mediatek/quantizer/annotator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
annotate_output_qspec as _annotate_output_qspec,
2222
QuantizationAnnotation,
2323
)
24+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2425

2526
from .qconfig import QuantizationConfig
2627

@@ -57,12 +58,12 @@ def _is_annotated(node: Node):
5758
return True if any of the node
5859
is annotated, otherwise return False
5960
"""
60-
KEY = "quantization_annotation"
61+
KEY = Q_ANNOTATION_KEY
6162
return KEY in node.meta and node.meta[KEY]._annotated
6263

6364

6465
def _mark_as_annotated(nodes: List[Node]):
65-
KEY = "quantization_annotation"
66+
KEY = Q_ANNOTATION_KEY
6667
for node in nodes:
6768
if KEY not in node.meta:
6869
node.meta[KEY] = QuantizationAnnotation()

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
QuantizationSpec,
5252
Quantizer,
5353
)
54+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
5455

5556

5657
class NeutronAtenQuantizer(Quantizer):
@@ -92,7 +93,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
9293

9394
for output, *custom_spec in anchors.output:
9495
# pyre-ignore[16]: no attribute
95-
output.meta["quantization_annotation"] = QuantizationAnnotation(
96+
output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
9697
# pyre-ignore[6]: incompatible parameter type
9798
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
9899
_annotated=True,
@@ -108,7 +109,7 @@ def annotate_inputs(
108109
for node, idx, *custom_spec in inputs:
109110
# pyre-ignore[16]: no attribute
110111
annotation = node.meta.get(
111-
"quantization_annotation",
112+
Q_ANNOTATION_KEY,
112113
QuantizationAnnotation(_annotated=True),
113114
)
114115
arg = (
@@ -122,21 +123,21 @@ def annotate_inputs(
122123
custom_spec[0] if custom_spec else spec
123124
)
124125
# pyre-ignore[16]: no attribute
125-
node.meta["quantization_annotation"] = annotation
126+
node.meta[Q_ANNOTATION_KEY] = annotation
126127

127128
def annotate_weights_or_biases(
128129
weights_or_biases: List[Tuple[fx.Node, int]],
129130
spec: Optional[QuantizationSpec],
130131
) -> None:
131132
for node, idx, *custom_spec in weights_or_biases:
132133
annotation = node.meta.get(
133-
"quantization_annotation",
134+
Q_ANNOTATION_KEY,
134135
QuantizationAnnotation(_annotated=True),
135136
)
136137
annotation.input_qspec_map[node.args[idx]] = (
137138
custom_spec[0] if custom_spec else spec
138139
)
139-
node.meta["quantization_annotation"] = annotation
140+
node.meta[Q_ANNOTATION_KEY] = annotation
140141

141142
# pyre-ignore[6]: incompatible parameter type
142143
annotate_inputs(anchors.inputs, input_act_qspec)

backends/nxp/quantizer/patterns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FixedQParamsQuantizationSpec,
2020
SharedQuantizationSpec,
2121
)
22+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2223

2324

2425
@dataclass
@@ -90,7 +91,7 @@ def get_anchors(
9091
prev_node = fused_partition[0].input_nodes[0]
9192

9293
# Previous node was not quantized => we are not able to share q-params
93-
if "quantization_annotation" not in prev_node.meta:
94+
if Q_ANNOTATION_KEY not in prev_node.meta:
9495
return None
9596

9697
qspec = SharedQuantizationSpec(prev_node)

backends/nxp/quantizer/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
SourcePartition,
2020
)
2121
from torchao.quantization.pt2e import ObserverOrFakeQuantize
22+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2223

2324

2425
def is_annotated(nodes: List[fx.Node]) -> bool:
2526
annotated = False
2627
for node in nodes:
2728
annotated = annotated or (
28-
"quantization_annotation" in node.meta
29-
and node.meta["quantization_annotation"]._annotated
29+
Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated
3030
)
3131
return annotated
3232

backends/openvino/quantizer/quantizer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
Quantizer,
3131
SharedQuantizationSpec,
3232
)
33-
34-
QUANT_ANNOTATION_KEY = "quantization_annotation"
33+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
3534

3635

3736
class QuantizationMode(Enum):
@@ -174,8 +173,8 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
174173
self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
175174

176175
for node, annotation in node_vs_torch_annotation.items():
177-
assert QUANT_ANNOTATION_KEY not in node.meta
178-
node.meta[QUANT_ANNOTATION_KEY] = annotation
176+
assert Q_ANNOTATION_KEY not in node.meta
177+
node.meta[Q_ANNOTATION_KEY] = annotation
179178
return model
180179

181180
@staticmethod

0 commit comments

Comments
 (0)