4141 no_outside_users ,
4242)
4343from torch import fx
44+ from torch .ao .quantization .quantizer .utils import _annotate_output_qspec
4445from torchao .quantization .pt2e import HistogramObserver , MinMaxObserver
4546from torchao .quantization .pt2e .quantizer import (
4647 ComposableQuantizer ,
@@ -236,6 +237,8 @@ def transform_for_annotation(
236237 return pass_runner (model ).graph_module
237238
238239 def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
240+ self ._annotate_inputs (model )
241+
239242 nodes = list (model .graph .nodes )
240243 for node in nodes :
241244 if (
@@ -251,5 +254,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
251254
252255 return model
253256
257+ def _is_input_annotated (self , node : fx .Node ) -> bool :
258+ return (
259+ "quantization_annotation" in node .meta
260+ and node .meta ["quantization_annotation" ]._annotated
261+ )
262+
263+ def _mark_input_node_as_annotated (self , node : fx .Node ) -> None :
264+ if "quantization_annotation" not in node .meta :
265+ node .meta ["quantization_annotation" ] = QuantizationAnnotation ()
266+ node .meta ["quantization_annotation" ]._annotated = True
267+
268+ def _annotate_inputs (self , model : fx .GraphModule ):
269+ for node in model .graph .nodes :
270+ if self ._is_input_annotated (node ):
271+ continue
272+
273+ if node .op == "placeholder" and len (node .users ) > 0 :
274+ _annotate_output_qspec (node , act_qspec )
275+ self ._mark_input_node_as_annotated (node )
276+
254277 def validate (self , model : torch .fx .GraphModule ) -> None :
255278 return super ().validate (model )
0 commit comments