4545 QuantizationSpec ,
4646 Quantizer ,
4747)
48+ from torchao .quantization .pt2e .quantizer .quantizer import Q_ANNOTATION_KEY
4849
4950
5051class NeutronAtenQuantizer (Quantizer ):
@@ -86,7 +87,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
8687
8788 for output , * custom_spec in anchors .output :
8889 # pyre-ignore[16]: no attribute
89- output .meta ["quantization_annotation" ] = QuantizationAnnotation (
90+ output .meta [Q_ANNOTATION_KEY ] = QuantizationAnnotation (
9091 # pyre-ignore[6]: incompatible parameter type
9192 output_qspec = (custom_spec [0 ] if custom_spec else output_act_qspec ),
9293 _annotated = True ,
@@ -102,7 +103,7 @@ def annotate_inputs(
102103 for node , idx , * custom_spec in inputs :
103104 # pyre-ignore[16]: no attribute
104105 annotation = node .meta .get (
105- "quantization_annotation" ,
106+ Q_ANNOTATION_KEY ,
106107 QuantizationAnnotation (_annotated = True ),
107108 )
108109 arg = (
@@ -116,21 +117,21 @@ def annotate_inputs(
116117 custom_spec [0 ] if custom_spec else spec
117118 )
118119 # pyre-ignore[16]: no attribute
119- node .meta ["quantization_annotation" ] = annotation
120+ node .meta [Q_ANNOTATION_KEY ] = annotation
120121
121122 def annotate_weights_or_biases (
122123 weights_or_biases : List [Tuple [fx .Node , int ]],
123124 spec : Optional [QuantizationSpec ],
124125 ) -> None :
125126 for node , idx , * custom_spec in weights_or_biases :
126127 annotation = node .meta .get (
127- "quantization_annotation" ,
128+ Q_ANNOTATION_KEY ,
128129 QuantizationAnnotation (_annotated = True ),
129130 )
130131 annotation .input_qspec_map [node .args [idx ]] = (
131132 custom_spec [0 ] if custom_spec else spec
132133 )
133- node .meta ["quantization_annotation" ] = annotation
134+ node .meta [Q_ANNOTATION_KEY ] = annotation
134135
135136 # pyre-ignore[6]: incompatible parameter type
136137 annotate_inputs (anchors .inputs , input_act_qspec )
0 commit comments