Skip to content

Commit b9d6de8

Browse files
cccclaifacebook-github-bot
authored andcommitted
Use Q_ANNOTATION_KEY
Summary: pytorch/ao#2525 introduced Q_ANNOTATION_KEY to avoid manually typing "quantization_annotation". Trying to apply it in our codebase Reviewed By: jerryzh168 Differential Revision: D78193037
1 parent cea9b23 commit b9d6de8

File tree

13 files changed

+109
-99
lines changed

13 files changed

+109
-99
lines changed

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,23 @@
1818
from torch.fx import GraphModule, Node
1919

2020
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
21+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2122

2223

2324
def is_annotated(node: Node) -> bool:
2425
"""Given a node return whether the node is annotated."""
2526
return (
26-
"quantization_annotation" in node.meta
27+
Q_ANNOTATION_KEY in node.meta
2728
and cast(
28-
QuantizationAnnotation, node.meta["quantization_annotation"]
29+
QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY]
2930
)._annotated
3031
)
3132

3233

3334
def is_output_annotated(node: Node) -> bool:
3435
"""Given a node, return whether the output of the node is annotated."""
35-
if "quantization_annotation" in node.meta:
36-
annotation = cast(QuantizationAnnotation, node.meta["quantization_annotation"])
36+
if Q_ANNOTATION_KEY in node.meta:
37+
annotation = cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])
3738
return annotation._annotated and annotation.output_qspec is not None
3839
else:
3940
return False
@@ -43,9 +44,9 @@ def mark_node_as_annotated(node: Node) -> None:
4344
"""Marks node as annotated. If needed, an empty QuantizationAnnotation is added
4445
to the quantization_annotation node meta entry.
4546
"""
46-
if "quantization_annotation" not in node.meta:
47-
node.meta["quantization_annotation"] = QuantizationAnnotation()
48-
node.meta["quantization_annotation"]._annotated = True
47+
if Q_ANNOTATION_KEY not in node.meta:
48+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
49+
node.meta[Q_ANNOTATION_KEY]._annotated = True
4950

5051

5152
def is_ok_for_quantization(node: Node, gm: GraphModule):

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
is_annotated,
3030
no_outside_users,
3131
)
32+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
3233

3334
from torch import fx
3435

@@ -127,7 +128,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
127128

128129
for output, *custom_spec in anchors.output:
129130
# pyre-ignore[16]: no attribute
130-
output.meta["quantization_annotation"] = QuantizationAnnotation(
131+
output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
131132
# pyre-ignore[6]: incompatible parameter type
132133
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
133134
_annotated=True,
@@ -143,7 +144,7 @@ def annotate_inputs(
143144
for node, idx, *custom_spec in inputs:
144145
# pyre-ignore[16]: no attribute
145146
annotation = node.meta.get(
146-
"quantization_annotation",
147+
Q_ANNOTATION_KEY,
147148
QuantizationAnnotation(_annotated=True),
148149
)
149150
arg = (
@@ -157,21 +158,21 @@ def annotate_inputs(
157158
custom_spec[0] if custom_spec else spec
158159
)
159160
# pyre-ignore[16]: no attribute
160-
node.meta["quantization_annotation"] = annotation
161+
node.meta[Q_ANNOTATION_KEY] = annotation
161162

162163
def annotate_weights_or_biases(
163164
weights_or_biases: List[Tuple[fx.Node, int]],
164165
spec: Optional[QuantizationSpec],
165166
) -> None:
166167
for node, idx, *custom_spec in weights_or_biases:
167168
annotation = node.meta.get(
168-
"quantization_annotation",
169+
Q_ANNOTATION_KEY,
169170
QuantizationAnnotation(_annotated=True),
170171
)
171172
annotation.input_qspec_map[node.args[idx]] = (
172173
custom_spec[0] if custom_spec else spec
173174
)
174-
node.meta["quantization_annotation"] = annotation
175+
node.meta[Q_ANNOTATION_KEY] = annotation
175176

176177
# pyre-ignore[6]: incompatible parameter type
177178
annotate_inputs(anchors.inputs, input_act_qspec)

backends/cadence/aot/quantizer/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
SourcePartition,
2222
)
2323
from torchao.quantization.pt2e import ObserverOrFakeQuantize
24+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2425

2526

2627
def quantize_tensor_multiplier(
@@ -88,8 +89,8 @@ def is_annotated(nodes: List[fx.Node]) -> bool:
8889
annotated = False
8990
for node in nodes:
9091
annotated = annotated or (
91-
"quantization_annotation" in node.meta
92-
and node.meta["quantization_annotation"]._annotated
92+
Q_ANNOTATION_KEY in node.meta
93+
and node.meta[Q_ANNOTATION_KEY]._annotated
9394
)
9495
return annotated
9596

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
QuantizationSpec,
2626
Quantizer,
2727
)
28+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2829

2930

3031
@dataclass(eq=True, frozen=True)
@@ -68,8 +69,8 @@ def annotate(self, model: GraphModule):
6869
continue
6970

7071
if (
71-
"quantization_annotation" in node.meta
72-
and node.meta["quantization_annotation"]._annotated
72+
Q_ANNOTATION_KEY in node.meta
73+
and node.meta[Q_ANNOTATION_KEY]._annotated
7374
):
7475
continue
7576

@@ -78,7 +79,7 @@ def annotate(self, model: GraphModule):
7879
node.args[1]: config.input_activation,
7980
}
8081

81-
node.meta["quantization_annotation"] = QuantizationAnnotation(
82+
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
8283
input_qspec_map=input_qspec_map,
8384
output_qspec=config.output_activation,
8485
_annotated=True,

backends/example/example_operators/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
8+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
89

910

1011
def _nodes_are_annotated(node_list):
1112
for node in node_list:
12-
quantization_annotation = node.meta.get("quantization_annotation", None)
13+
quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None)
1314
if not quantization_annotation:
1415
return False
1516
if quantization_annotation._annotated:
@@ -23,11 +24,11 @@ def _annotate_nodes(node_tuples, quant_spec, input_node=False):
2324
for node_tuple in node_tuples:
2425
node = node_tuple[0]
2526
quant_annotation = node.meta.get(
26-
"quantization_annotation", QuantizationAnnotation(_annotated=True)
27+
Q_ANNOTATION_KEY, QuantizationAnnotation(_annotated=True)
2728
)
2829
if input_node:
2930
input_node = node_tuple[1]
3031
quant_annotation.input_qspec_map[input_node] = quant_spec
3132
else:
3233
quant_annotation.output_qspec = quant_spec
33-
node.meta["quantization_annotation"] = quant_annotation
34+
node.meta[Q_ANNOTATION_KEY] = quant_annotation

backends/mediatek/quantizer/annotator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
annotate_output_qspec as _annotate_output_qspec,
2222
QuantizationAnnotation,
2323
)
24+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2425

2526
from .qconfig import QuantizationConfig
2627

@@ -57,12 +58,12 @@ def _is_annotated(node: Node):
5758
return True if any of the node
5859
is annotated, otherwise return False
5960
"""
60-
KEY = "quantization_annotation"
61+
KEY = Q_ANNOTATION_KEY
6162
return KEY in node.meta and node.meta[KEY]._annotated
6263

6364

6465
def _mark_as_annotated(nodes: List[Node]):
65-
KEY = "quantization_annotation"
66+
KEY = Q_ANNOTATION_KEY
6667
for node in nodes:
6768
if KEY not in node.meta:
6869
node.meta[KEY] = QuantizationAnnotation()

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
QuantizationSpec,
4646
Quantizer,
4747
)
48+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
4849

4950

5051
class NeutronAtenQuantizer(Quantizer):
@@ -86,7 +87,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
8687

8788
for output, *custom_spec in anchors.output:
8889
# pyre-ignore[16]: no attribute
89-
output.meta["quantization_annotation"] = QuantizationAnnotation(
90+
output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
9091
# pyre-ignore[6]: incompatible parameter type
9192
output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
9293
_annotated=True,
@@ -102,7 +103,7 @@ def annotate_inputs(
102103
for node, idx, *custom_spec in inputs:
103104
# pyre-ignore[16]: no attribute
104105
annotation = node.meta.get(
105-
"quantization_annotation",
106+
Q_ANNOTATION_KEY,
106107
QuantizationAnnotation(_annotated=True),
107108
)
108109
arg = (
@@ -116,21 +117,21 @@ def annotate_inputs(
116117
custom_spec[0] if custom_spec else spec
117118
)
118119
# pyre-ignore[16]: no attribute
119-
node.meta["quantization_annotation"] = annotation
120+
node.meta[Q_ANNOTATION_KEY] = annotation
120121

121122
def annotate_weights_or_biases(
122123
weights_or_biases: List[Tuple[fx.Node, int]],
123124
spec: Optional[QuantizationSpec],
124125
) -> None:
125126
for node, idx, *custom_spec in weights_or_biases:
126127
annotation = node.meta.get(
127-
"quantization_annotation",
128+
Q_ANNOTATION_KEY,
128129
QuantizationAnnotation(_annotated=True),
129130
)
130131
annotation.input_qspec_map[node.args[idx]] = (
131132
custom_spec[0] if custom_spec else spec
132133
)
133-
node.meta["quantization_annotation"] = annotation
134+
node.meta[Q_ANNOTATION_KEY] = annotation
134135

135136
# pyre-ignore[6]: incompatible parameter type
136137
annotate_inputs(anchors.inputs, input_act_qspec)

backends/nxp/quantizer/patterns.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FixedQParamsQuantizationSpec,
2020
SharedQuantizationSpec,
2121
)
22+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2223

2324

2425
@dataclass
@@ -90,7 +91,7 @@ def get_anchors(
9091
prev_node = fused_partition[0].input_nodes[0]
9192

9293
# Previous node was not quantized => we are not able to share q-params
93-
if "quantization_annotation" not in prev_node.meta:
94+
if Q_ANNOTATION_KEY not in prev_node.meta:
9495
return None
9596

9697
qspec = SharedQuantizationSpec(prev_node)

backends/nxp/quantizer/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
SourcePartition,
2020
)
2121
from torchao.quantization.pt2e import ObserverOrFakeQuantize
22+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
2223

2324

2425
def is_annotated(nodes: List[fx.Node]) -> bool:
2526
annotated = False
2627
for node in nodes:
2728
annotated = annotated or (
28-
"quantization_annotation" in node.meta
29-
and node.meta["quantization_annotation"]._annotated
29+
Q_ANNOTATION_KEY in node.meta
30+
and node.meta[Q_ANNOTATION_KEY]._annotated
3031
)
3132
return annotated
3233

backends/openvino/quantizer/quantizer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
Quantizer,
3131
SharedQuantizationSpec,
3232
)
33-
34-
QUANT_ANNOTATION_KEY = "quantization_annotation"
33+
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
3534

3635

3736
class QuantizationMode(Enum):
@@ -174,8 +173,8 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
174173
self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)
175174

176175
for node, annotation in node_vs_torch_annotation.items():
177-
assert QUANT_ANNOTATION_KEY not in node.meta
178-
node.meta[QUANT_ANNOTATION_KEY] = annotation
176+
assert Q_ANNOTATION_KEY not in node.meta
177+
node.meta[Q_ANNOTATION_KEY] = annotation
179178
return model
180179

181180
@staticmethod

0 commit comments

Comments
 (0)