@@ -1041,13 +1041,13 @@ def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...
10411041def  quantized_relu_asym8u_asym8u_per_tensor () ->  torch .Tensor : ...
10421042
10431043
1044- @impl (m , "requantize" ) 
1045- def  requantize (
1044+ @impl (m , "requantize.per_tensor " ) 
1045+ def  requantize_per_tensor (
10461046    input : torch .Tensor ,
1047-     in_scale : torch . Tensor ,
1048-     in_zero_point : torch . Tensor ,
1049-     out_scale : torch . Tensor ,
1050-     out_zero_point : torch . Tensor ,
1047+     in_scale : float ,
1048+     in_zero_point : int ,
1049+     out_scale : float ,
1050+     out_zero_point : int ,
10511051    dtype : ScalarType ,
10521052) ->  torch .Tensor :
10531053    if  dtype  in  qdtype_map :
@@ -1056,11 +1056,6 @@ def requantize(
10561056            torch .dequantize (input ), out_scale , out_zero_point , qdtype_map [dtype ]
10571057        )
10581058
1059-     # For in_scale or out_scale other than scalar, it requires quant/dequant 
1060-     # per channel, but the channel dimension value is missing 
1061-     if  in_scale .numel () >  1  or  out_scale .numel () >  1 :
1062-         raise  NotImplementedError ("Only scalar scales are supported" )
1063- 
10641059    quant_min  =  torch .iinfo (input .dtype ).min 
10651060    quant_max  =  torch .iinfo (input .dtype ).max 
10661061    # pyre-fixme[6]: This dtype is actually the right one. 
@@ -1070,14 +1065,14 @@ def requantize(
10701065    return  torch .ops .quantized_decomposed .quantize_per_tensor (
10711066        torch .ops .quantized_decomposed .dequantize_per_tensor (
10721067            input ,
1073-             in_scale . flatten ()[ 0 ] ,
1074-             in_zero_point . flatten ()[ 0 ] ,
1068+             in_scale ,
1069+             in_zero_point ,
10751070            quant_min ,
10761071            quant_max ,
10771072            input .dtype ,
10781073        ),
1079-         out_scale . flatten ()[ 0 ] ,
1080-         out_zero_point . flatten ()[ 0 ] ,
1074+         out_scale ,
1075+         out_zero_point ,
10811076        out_quant_min ,
10821077        out_quant_max ,
10831078        dtype ,
0 commit comments