3535 no_outside_users ,
3636)
3737from torch import fx
38+ from torch .ao .quantization .quantizer .utils import _annotate_output_qspec
3839from torchao .quantization .pt2e import HistogramObserver , MinMaxObserver
3940from torchao .quantization .pt2e .quantizer import (
4041 ComposableQuantizer ,
@@ -224,6 +225,8 @@ def transform_for_annotation(
224225 return pass_runner (model ).graph_module
225226
226227 def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
228+ self ._annotate_inputs (model )
229+
227230 nodes = list (model .graph .nodes )
228231 for node in nodes :
229232 if (
@@ -239,5 +242,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
239242
240243 return model
241244
245+ def _is_input_annotated (self , node : fx .Node ) -> bool :
246+ return (
247+ "quantization_annotation" in node .meta
248+ and node .meta ["quantization_annotation" ]._annotated
249+ )
250+
251+ def _mark_input_node_as_annotated (self , node : fx .Node ) -> None :
252+ if "quantization_annotation" not in node .meta :
253+ node .meta ["quantization_annotation" ] = QuantizationAnnotation ()
254+ node .meta ["quantization_annotation" ]._annotated = True
255+
256+ def _annotate_inputs (self , model : fx .GraphModule ):
257+ for node in model .graph .nodes :
258+ if self ._is_input_annotated (node ):
259+ continue
260+
261+ if node .op == "placeholder" and len (node .users ) > 0 :
262+ _annotate_output_qspec (node , act_qspec )
263+ self ._mark_input_node_as_annotated (node )
264+
242265 def validate (self , model : torch .fx .GraphModule ) -> None :
243266 return super ().validate (model )
0 commit comments