4141 no_outside_users ,
4242)
4343from torch import fx
44+ from torch .ao .quantization .quantizer .utils import _annotate_output_qspec
4445from torchao .quantization .pt2e import HistogramObserver , MinMaxObserver
4546from torchao .quantization .pt2e .quantizer import (
4647 ComposableQuantizer ,
@@ -237,6 +238,8 @@ def transform_for_annotation(
237238 return pass_runner (model ).graph_module
238239
239240 def annotate (self , model : torch .fx .GraphModule ) -> torch .fx .GraphModule :
241+ self ._annotate_inputs (model )
242+
240243 nodes = list (model .graph .nodes )
241244 for node in nodes :
242245 if (
@@ -252,5 +255,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
252255
253256 return model
254257
258+ def _is_input_annotated (self , node : fx .Node ) -> bool :
259+ return (
260+ "quantization_annotation" in node .meta
261+ and node .meta ["quantization_annotation" ]._annotated
262+ )
263+
264+ def _mark_input_node_as_annotated (self , node : fx .Node ) -> None :
265+ if "quantization_annotation" not in node .meta :
266+ node .meta ["quantization_annotation" ] = QuantizationAnnotation ()
267+ node .meta ["quantization_annotation" ]._annotated = True
268+
269+ def _annotate_inputs (self , model : fx .GraphModule ):
270+ for node in model .graph .nodes :
271+ if self ._is_input_annotated (node ):
272+ continue
273+
274+ if node .op == "placeholder" and len (node .users ) > 0 :
275+ _annotate_output_qspec (node , act_qspec )
276+ self ._mark_input_node_as_annotated (node )
277+
255278 def validate (self , model : torch .fx .GraphModule ) -> None :
256279 return super ().validate (model )
0 commit comments