@@ -250,7 +250,9 @@ def annotate_masked_fill(node: Node, quantization_config: QuantizationConfig) ->
250250 )
251251
252252
253- @register_annotator ([torch .ops .aten .mul , torch .ops .aten .mul .Tensor ])
253+ @register_annotator (
254+ [torch .ops .aten .mul , torch .ops .aten .mul .Tensor , torch .ops .aten .mul_ .Tensor ]
255+ )
254256def annotate_mul (node : Node , quantization_config : QuantizationConfig ) -> None :
255257 annotate_binary (node , quantization_config )
256258
@@ -606,9 +608,35 @@ def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None
606608
607609@register_annotator ([torch .ops .aten .slice .Tensor ])
608610def annotate_slice (node : Node , quantization_config : QuantizationConfig ) -> None :
611+ if _is_annotated ([node ]) or not _is_float_tensor (node ):
612+ return
609613 annotate_single_in_single_out (node , quantization_config )
610614
611615
616+ @register_annotator ([torch .ops .aten .slice_scatter .default ])
617+ def annotate_slice_scatter (node : Node , quantization_config : QuantizationConfig ) -> None :
618+ if _is_annotated ([node ]):
619+ return
620+
621+ input_act_qspec = quantization_config .input_activation
622+ output_act_qspec = quantization_config .output_activation
623+
624+ input_qspec_map = {}
625+ input_act0 = node .args [0 ]
626+ if isinstance (input_act0 , Node ):
627+ input_qspec_map [input_act0 ] = input_act_qspec
628+
629+ input_act1 = node .args [1 ]
630+ if isinstance (input_act1 , Node ):
631+ input_qspec_map [input_act1 ] = input_act_qspec
632+
633+ node .meta [QUANT_ANNOTATION_KEY ] = QuantizationAnnotation (
634+ input_qspec_map = input_qspec_map ,
635+ output_qspec = output_act_qspec ,
636+ _annotated = True ,
637+ )
638+
639+
612640@register_annotator ([torch .ops .aten .sqrt .default ])
613641def annotate_sqrt (node : Node , quantization_config : QuantizationConfig ) -> None :
614642 annotate_single_in_single_out (node , quantization_config )
@@ -801,16 +829,17 @@ def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> N
801829
802830@register_annotator ([torch .ops .aten .index .Tensor ])
803831def annotate_index (node : Node , quantization_config : QuantizationConfig ) -> None :
832+ if _is_annotated ([node ]) or not _is_float_tensor (node ):
833+ return
804834 annotate_in_out_obs_sharing_op (node , quantization_config )
805- if not _is_annotated ([node ]):
806- input_qspec_map = {}
807- input = node .args [0 ]
808- input_qspec_map [input ] = quantization_config .input_activation
809- node .meta [QUANT_ANNOTATION_KEY ] = QuantizationAnnotation (
810- input_qspec_map = input_qspec_map ,
811- output_qspec = SharedQuantizationSpec ((input , node )),
812- _annotated = True ,
813- )
835+ input_qspec_map = {}
836+ input = node .args [0 ]
837+ input_qspec_map [input ] = quantization_config .input_activation
838+ node .meta [QUANT_ANNOTATION_KEY ] = QuantizationAnnotation (
839+ input_qspec_map = input_qspec_map ,
840+ output_qspec = SharedQuantizationSpec ((input , node )),
841+ _annotated = True ,
842+ )
814843
815844
816845@register_annotator (
@@ -1270,7 +1299,7 @@ def annotate_where(node: Node, quantization_config: QuantizationConfig) -> None:
12701299 )
12711300
12721301
1273- @register_annotator ([torch .ops .aten .zeros .default ])
1302+ @register_annotator ([torch .ops .aten .zeros .default , torch . ops . aten . zeros_like . default ])
12741303def annotate_zeros (node : Node , quantization_config : QuantizationConfig ) -> None :
12751304 if _is_annotated ([node ]) or not _is_float_tensor (node ):
12761305 return
0 commit comments