@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
 def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "requantize")
-def requantize(
+@impl(m, "requantize.per_tensor")
+def requantize_per_tensor(
     input: torch.Tensor,
-    in_scale: torch.Tensor,
-    in_zero_point: torch.Tensor,
-    out_scale: torch.Tensor,
-    out_zero_point: torch.Tensor,
+    in_scale: float,
+    in_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
     dtype: ScalarType,
 ) -> torch.Tensor:
     if dtype in qdtype_map:
@@ -1056,11 +1056,6 @@ def requantize(
             torch.dequantize(input), out_scale, out_zero_point, qdtype_map[dtype]
         )
 
-    # For in_scale or out_scale other than scalar, it requires quant/dequant
-    # per channel, but the channel dimension value is missing
-    if in_scale.numel() > 1 or out_scale.numel() > 1:
-        raise NotImplementedError("Only scalar scales are supported")
-
     quant_min = torch.iinfo(input.dtype).min
     quant_max = torch.iinfo(input.dtype).max
     # pyre-fixme[6]: This dtype is actually the right one.
@@ -1070,14 +1065,14 @@ def requantize(
     return torch.ops.quantized_decomposed.quantize_per_tensor(
         torch.ops.quantized_decomposed.dequantize_per_tensor(
            input,
-            in_scale.flatten()[0],
-            in_zero_point.flatten()[0],
+            in_scale,
+            in_zero_point,
            quant_min,
            quant_max,
            input.dtype,
        ),
-        out_scale.flatten()[0],
-        out_zero_point.flatten()[0],
+        out_scale,
+        out_zero_point,
        out_quant_min,
        out_quant_max,
        dtype,
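
For context, a minimal reference sketch (not part of the diff above) of the per-tensor requantize semantics the new overload implements with scalar in_scale/in_zero_point/out_scale/out_zero_point: dequantize with the input parameters, then re-quantize with the output parameters. The helper name and the int8 default dtype are illustrative assumptions, not names from this patch.

import torch

def requantize_per_tensor_reference(
    input: torch.Tensor,   # quantized values stored in an integer tensor
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
    out_dtype: torch.dtype = torch.int8,
) -> torch.Tensor:
    # Dequantize with the input parameters: real = (q - zero_point) * scale
    real = (input.to(torch.float32) - in_zero_point) * in_scale
    # Re-quantize with the output parameters and clamp to the target dtype range.
    q = torch.round(real / out_scale) + out_zero_point
    info = torch.iinfo(out_dtype)
    return q.clamp(info.min, info.max).to(out_dtype)

x = torch.tensor([-10, 0, 25, 100], dtype=torch.int8)
print(requantize_per_tensor_reference(x, 0.05, 2, 0.1, -1))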