
# pyre-strict

-
from typing import Callable

import torch
+import torch.nn as nn
+import torch.nn.functional as F

from executorch.exir.scalar_type import ScalarType
+from on_device_ai.Assistant.Jarvis.nn.roi_align_utils import convertBoxPosToTuringConfig
from torch.library import impl, Library

-
m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
+torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")

qdtype_map: dict[ScalarType, torch.dtype] = {
    ScalarType.QINT8: torch.qint8,
@@ -38,7 +40,7 @@ def quantize_per_tensor(

    Args:
        - input_tensor (Tensor): input tensor
-        - scale (float): Inverse of quantization scale. Derived from the ratio
-          between the min/max of the floating-point tensor and the
-          min/max of the quantized range, and then inverted.
+        - scale (float): Quantization scale. Derived from the ratio
+          between the min/max of the floating-point tensor and the
+          min/max of the quantized range.
        - zero_point (int): The point which represents 0 in the quantized
@@ -64,7 +66,8 @@ def quantize_per_tensor(
        f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
    )

-    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    inv_scale = 1.0 / scale
+    quantized = torch.round(input_tensor * inv_scale + zero_point).to(dtype)
    return torch.max(
        torch.min(quantized, torch.tensor(quant_max)),
        torch.tensor(quant_min),
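
A minimal sketch of the new contract (values illustrative): scale is now the
real quantization step, so q = round(x / scale) + zero_point, clamped to
[quant_min, quant_max].

    x = torch.tensor([0.0, 0.5, 1.0])
    scale, zero_point = 1.0 / 127, 0
    q = torch.round(x / scale + zero_point).clamp(-128, 127).to(torch.int8)
    # q == tensor([0, 64, 127], dtype=torch.int8)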
@@ -97,7 +100,7 @@ def dequantize_per_tensor(
          is already provided.
        - quant_max (int): The largest value in the quantized domain. Unused since scale
          is already provided.
-        - dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
+        - dtype (torch.dtype): The expected type of the input tensor; must match input_tensor.dtype.
    """
    supported_quant_types = [
        torch.int8,
@@ -108,23 +111,15 @@ def dequantize_per_tensor(
    ]
    if input_tensor.dtype not in supported_quant_types:
        raise ValueError(f"Input dtype must be one of {supported_quant_types}")
-    supported_dequant_types = [
-        torch.float,
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ]
-    if dtype not in supported_dequant_types:
-        raise ValueError(
-            f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
-        )
+    if input_tensor.dtype != dtype:
+        raise ValueError(f"Input dtype {input_tensor.dtype} must match dtype argument {dtype}")

    # Needed to prevent underflow in cases where the zero_point is larger than
    # the quantized value.
    if not input_tensor.dtype.is_signed:
        input_tensor = input_tensor.to(torch.int32)

-    return (input_tensor - zero_point).to(dtype) * scale
+    return ((input_tensor - zero_point) * scale).to(torch.float32)


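A minimal sketch of the new dequantize behavior (values illustrative):
unsigned inputs are widened first, the subtraction and multiply happen before
the cast, and the result is always float32.

    q = torch.tensor([10, 200], dtype=torch.uint8).to(torch.int32)
    scale, zero_point = 0.05, 128
    x = ((q - zero_point) * scale).to(torch.float32)
    # x == tensor([-5.9000, 3.6000])
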
@impl(m, "quantized_add.per_tensor")
@@ -180,12 +175,10 @@ def quantized_add_per_tensor(
    dequant_X = X_scale * (X - X_zero_point)
    dequant_Y = Y_scale * (Y - Y_zero_point)

-    out_scale_inv = 1 / out_scale
-
    # q_min/q_max are unused args
    return quantize_per_tensor(
        dequant_X + dequant_Y,
-        out_scale_inv,
+        out_scale,
        out_zero_point,
        torch.iinfo(dtype).min,
        torch.iinfo(dtype).max,
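
A worked example of the add-then-requantize math above (all values assumed):

    # X = [10] at scale 0.1 (zp 0) -> 1.0; Y = [20] at scale 0.05 (zp 0) -> 1.0
    dequant_sum = 0.1 * (10 - 0) + 0.05 * (20 - 0)   # 2.0
    # quantize_per_tensor with out_scale = 0.02, out_zero_point = 0:
    # round(2.0 / 0.02) + 0 = 100, which fits in int8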
@@ -260,7 +253,6 @@ def quantized_linear_common(
        - offset (Tensor): Unused
    """
    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift)
-    out_scale_inv = 1 / out_scale

    N, K = weight.shape

@@ -281,7 +273,7 @@ def quantized_linear_common(
    )
    return quantize_per_tensor(
        out,
-        out_scale_inv,
+        out_scale,
        out_zero_point,
        torch.iinfo(dtype).min,
        torch.iinfo(dtype).max,
@@ -399,6 +391,17 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


+@impl(m, "fully_connected")
+def fully_connected(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+) -> torch.Tensor:
+    if input_tensor.shape[0] != 1:
+        raise ValueError("Fully connected linear only supports batch size of 1")
+    return F.linear(input_tensor, weight, bias)
+
+
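A minimal usage sketch for the new op (shapes assumed; this presumes a
matching cadence::fully_connected schema is already declared for this @impl
to bind against):

    x = torch.randn(1, 16)                           # batch size must be 1
    w = torch.randn(8, 16)                           # (out_features, in_features)
    b = torch.zeros(8)
    y = torch.ops.cadence.fully_connected(x, w, b)   # shape (1, 8)
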
@impl(m, "quantized_matmul")
def quantized_matmul(
    X: torch.Tensor,
@@ -538,15 +541,15 @@ def quantized_layer_norm_per_tensor(
    )

    float_input_tensor = dequantize_per_tensor(
-        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+        input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
    )
    out = torch.nn.functional.layer_norm(
        float_input_tensor, normalized_shape, weight, bias, eps=eps
    )

    return quantize_per_tensor(
        out,
-        1 / output_scale,
+        output_scale,
        output_zero_point,
        torch.iinfo(input_tensor.dtype).min,
        torch.iinfo(input_tensor.dtype).max,
@@ -615,7 +618,7 @@ def quantized_conv_per_tensor(

    return quantize_per_tensor(
        float_out,
-        1.0 / output_scale,
+        output_scale,
        output_zero_point,
        torch.iinfo(input_tensor.dtype).min,
        torch.iinfo(input_tensor.dtype).max,
@@ -942,7 +945,7 @@ def quantized_relu_common(
    if X.dtype not in supported_dtypes:
        raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")

-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2 ** out_shift)
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2 ** out_shift))
    dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
    return quantize_per_tensor(
        dequantized_X,
@@ -1068,3 +1071,45 @@ def requantize(
        out_quant_max,
        dtype,
    )
+
+
+@impl(m, "roi_align_box_processor")
+def roi_align_box_processor(
+    rois: torch.Tensor,
+    output_size_h: int,
+    output_size_w: int,
+    sampling_ratio: int,
+    aligned: bool,
+) -> torch.Tensor:
+    K = rois.shape[0]
+    turing_rois = []
+    for i in range(K):
+        x1 = rois[i][1].item()
+        y1 = rois[i][2].item()
+        x2 = rois[i][3].item()
+        y2 = rois[i][4].item()
+        topLeftXY = (x1, y1)
+        bottomRightXY = (x2, y2)
+        turing_roi = convertBoxPosToTuringConfig(
+            topLeftXY,
+            bottomRightXY,
+            K,
+            output_size_h,
+            output_size_w,
+            sampling_ratio,
+            aligned,
+        )
+        turing_rois.append(torch.frombuffer(turing_roi, dtype=torch.uint8))
+
+    out = torch.stack(turing_rois)
+    return out
+
+
+@impl(m, "rms_norm")
+def rms_norm(
+    X: torch.Tensor,
+    normalized_shape: tuple[int],
+    W: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
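
A minimal usage sketch (shapes and eps assumed; presumes a matching
cadence::rms_norm schema exists for this @impl to bind against). nn.RMSNorm's
own affine weight defaults to ones, so multiplying by W afterwards applies
the learned gain:

    x = torch.randn(2, 8)
    w = torch.ones(8)
    y = torch.ops.cadence.rms_norm(x, (8,), w, 1e-6)   # same shape as x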