2525)
2626from  torchao .quantization .pt2e .quantizer .quantizer  import  Q_ANNOTATION_KEY 
2727
28+ from  .observers .concat_observer  import  ConcatObserver 
29+ 
2830from  .qconfig  import  (
2931    get_16a16w_qnn_ptq_config ,
3032    get_16a4w_qnn_qat_config ,
@@ -691,7 +693,7 @@ def annotate_sign(node: Node, quantization_config: QuantizationConfig) -> None:
691693
692694@register_annotator ([torch .ops .aten .slice .Tensor ]) 
693695def  annotate_slice (node : Node , quantization_config : QuantizationConfig ) ->  None :
694-     annotate_single_in_single_out (node , quantization_config )
696+     annotate_single_in_share_out (node , quantization_config )
695697
696698
697699@register_annotator ([torch .ops .aten .slice_scatter .default ]) 
@@ -1277,31 +1279,40 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
12771279
12781280@register_annotator ([torch .ops .aten .cat .default , torch .ops .aten .concat .default ]) 
12791281def  annotate_cat (node : Node , quantization_config : QuantizationConfig ) ->  None :
1280-     input_nodes  =  node .args [0 ]
12811282    if  _is_annotated ([node ]) or  not  _is_float_tensor (node ):
12821283        return 
12831284
1284-     assert  isinstance (input_nodes , Sequence )
1285- 
1286-     first_input_node  =  input_nodes [0 ]
1287-     input_qspec_map  =  {}
1288-     assert  isinstance (first_input_node , Node )
1289-     assert  isinstance (node , Node )
1290-     if  _is_float_tensor (first_input_node ):
1291-         input_qspec_map [first_input_node ] =  quantization_config .input_activation 
1292-         share_qparams_with_input_act0_qspec  =  SharedQuantizationSpec (
1293-             (first_input_node , node )
1294-         )
1295- 
1296-     for  input_node  in  input_nodes [1 :]:
1297-         if  input_node  not  in   input_qspec_map :
1298-             assert  isinstance (input_node , Node )
1299-             if  _is_float_tensor (input_node ):
1300-                 input_qspec_map [input_node ] =  share_qparams_with_input_act0_qspec 
1301- 
1285+     input_qspec_map , input_nodes  =  {}, node .args [0 ]
1286+     for  input  in  input_nodes :
1287+         input_qspec  =  input .meta .get (Q_ANNOTATION_KEY , None )
1288+         if  (
1289+             # placeholder 
1290+             input_qspec  is  None 
1291+             or 
1292+             # keep the shared qspec here so the data range propagates 
1293+             # without introducing extra requantizations 
1294+             not  isinstance (input_qspec .output_qspec , SharedQuantizationSpec )
1295+         ):
1296+             input_qspec_map [input ] =  quantization_config .input_activation 
1297+ 
1298+     output_qspec  =  QuantizationSpec (
1299+         dtype = quantization_config .output_activation .dtype ,
1300+         qscheme = quantization_config .output_activation .qscheme ,
1301+         quant_max = quantization_config .output_activation .quant_max ,
1302+         quant_min = quantization_config .output_activation .quant_min ,
1303+         observer_or_fake_quant_ctr = ConcatObserver .with_args (
1304+             # we need to know the concat node in order to hack all the input observers' data ranges; 
1305+             # since deep copy of a fake tensor (node.meta["val"]) is inhibited, 
1306+             # we can currently only ship the graph & node name and perform postprocessing inside the observer 
1307+             ** {
1308+                 "node_name" : node .name ,
1309+                 "graph" : node .graph ,
1310+             }
1311+         ),
1312+     )
13021313    node .meta [Q_ANNOTATION_KEY ] =  QuantizationAnnotation (
13031314        input_qspec_map = input_qspec_map ,
1304-         output_qspec = share_qparams_with_input_act0_qspec ,
1315+         output_qspec = output_qspec ,
13051316        _annotated = True ,
13061317    )
13071318
@@ -1345,6 +1356,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
13451356    input_act  =  node .args [0 ]
13461357    assert  isinstance (input_act , Node )
13471358    input_qspec_map [input_act ] =  quantization_config .input_activation 
1359+     share_qparams_with_input_node_qspec  =  SharedQuantizationSpec ((input_act , node ))
13481360
13491361    node .meta [Q_ANNOTATION_KEY ] =  QuantizationAnnotation (
13501362        input_qspec_map = input_qspec_map ,
@@ -1353,7 +1365,7 @@ def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
13531365
13541366    for  user  in  node .users :
13551367        user .meta [Q_ANNOTATION_KEY ] =  QuantizationAnnotation (
1356-             output_qspec = quantization_config . output_activation ,
1368+             output_qspec = share_qparams_with_input_node_qspec ,
13571369            _annotated = True ,
13581370        )
13591371
0 commit comments