
Commit fb8a6a0

Andrew Grebenisan authored and facebook-github-bot committed
Migrate uncommon cadence custom ops from Jarvis.nn.ref_implementations -> executorch ref_implementations
Summary: It turns out there was duplication in the cadence custom op ref implementation files, which could lead to an op-name registry collision ("op name was already registered"). Resolved by migrating the uncommon ops from Jarvis.nn.ref_implementations to the executorch ref_implementations, deleting the Jarvis file, and updating all of the dependencies.

Reviewed By: mcremon-meta

Differential Revision: D82566217
1 parent 108d29d commit fb8a6a0
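The collision described in the summary comes from torch.library registration: two files that register a reference implementation for the same op name, namespace, and dispatch key cannot both be imported. A minimal sketch of how this typically surfaces, using a hypothetical `demo_ns` namespace and `my_op` operator rather than the actual cadence ops:

import torch
from torch.library import Library, impl

# Hypothetical namespace/op, used only to illustrate the failure mode.
defs = Library("demo_ns", "DEF")
defs.define("my_op(Tensor x) -> Tensor")

m = Library("demo_ns", "IMPL", "CompositeExplicitAutograd")

@impl(m, "my_op")
def my_op_ref(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# A second registration for the same op and dispatch key -- e.g. from a
# duplicated ref-implementation file imported elsewhere -- is rejected.
try:
    @impl(m, "my_op")
    def my_op_ref_dup(x: torch.Tensor) -> torch.Tensor:
        return x + 1
except RuntimeError as exc:
    print(f"duplicate registration rejected: {exc}")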

File tree

3 files changed (+75, −29 lines)


backends/cadence/aot/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -130,6 +130,8 @@ runtime.python_library(
     deps = [
         "fbcode//caffe2:torch",
         "fbcode//executorch/exir:scalar_type",
+        "fbcode//on_device_ai/Assistant/Jarvis/nn:roi_align_utils",
+        "fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
     ],
 )

backends/cadence/aot/ref_implementations.py

Lines changed: 70 additions & 25 deletions
@@ -6,16 +6,18 @@

 # pyre-strict

-
 from typing import Callable

 import torch
+import torch.nn as nn
+import torch.nn.functional as F

 from executorch.exir.scalar_type import ScalarType
+from on_device_ai.Assistant.Jarvis.nn.roi_align_utils import convertBoxPosToTuringConfig
 from torch.library import impl, Library

-
 m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
+torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")

 qdtype_map: dict[ScalarType, torch.dtype] = {
     ScalarType.QINT8: torch.qint8,
@@ -38,7 +40,7 @@ def quantize_per_tensor(

     Args:
         - input_tensor (Tensor): input tensor
-        - scale (float): Inverse of quantization scale. Derived from the ratio
+        - scale (float): Quantization scale. Derived from the ratio
            between the min/max of the floating-point tensor and the
            min/max of the quantized range, and then inverted.
         - zero_point (int): The point which represents 0 in the quantized
@@ -64,7 +66,8 @@ def quantize_per_tensor(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )

-    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    inv_scale = 1.0 / scale
+    quantized = torch.round(input_tensor * inv_scale + zero_point).to(dtype)
     return torch.max(
         torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
@@ -97,7 +100,7 @@ def dequantize_per_tensor(
            is already provided.
         - quant_max (int): The largest value in the quantized domain. Unused since scale
            is already provided.
-        - dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
+        - dtype (torch.dtype): The type of the input tensor.
     """
     supported_quant_types = [
         torch.int8,
@@ -108,23 +111,15 @@
     ]
     if input_tensor.dtype not in supported_quant_types:
         raise ValueError(f"Input dtype must be one of {supported_quant_types}")
-    supported_dequant_types = [
-        torch.float,
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ]
-    if dtype not in supported_dequant_types:
-        raise ValueError(
-            f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
-        )
+    if input_tensor.dtype != dtype:
+        raise ValueError("Input dtype must match dtype")

     # Needed to prevent underflow in cases where the zero_point is larger than
     # the quantized value.
     if not input_tensor.dtype.is_signed:
         input_tensor = input_tensor.to(torch.int32)

-    return (input_tensor - zero_point).to(dtype) * scale
+    return ((input_tensor - zero_point) * scale).to(torch.float32)


 @impl(m, "quantized_add.per_tensor")
@@ -180,12 +175,10 @@ def quantized_add_per_tensor(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)

-    out_scale_inv = 1 / out_scale
-
     # q_min/q_max are unused args
     return quantize_per_tensor(
         dequant_X + dequant_Y,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
@@ -260,7 +253,6 @@ def quantized_linear_common(
        - offset (Tensor): Unused
     """
     out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
-    out_scale_inv = 1 / out_scale

     N, K = weight.shape

@@ -281,7 +273,7 @@
     )
     return quantize_per_tensor(
         out,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
@@ -399,6 +391,17 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
 def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...


+@impl(m, "fully_connected")
+def fully_connected(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+) -> torch.Tensor:
+    if input_tensor.shape[0] != 1:
+        raise ValueError("Fully connected linear only supports batch size of 1")
+    return F.linear(input_tensor, weight, bias)
+
+
 @impl(m, "quantized_matmul")
 def quantized_matmul(
     X: torch.Tensor,
@@ -538,15 +541,15 @@ def quantized_layer_norm_per_tensor(
     )

     float_input_tensor = dequantize_per_tensor(
-        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+        input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
     )
     out = torch.nn.functional.layer_norm(
         float_input_tensor, normalized_shape, weight, bias, eps=eps
     )

     return quantize_per_tensor(
         out,
-        1 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -615,7 +618,7 @@ def quantized_conv_per_tensor(

     return quantize_per_tensor(
         float_out,
-        1.0 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -942,7 +945,7 @@ def quantized_relu_common(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")

-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
     dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
     return quantize_per_tensor(
         dequantized_X,
@@ -1068,3 +1071,45 @@ def requantize(
         out_quant_max,
         dtype,
     )
+
+
+@impl(m, "roi_align_box_processor")
+def roi_align_box_processor(
+    rois: torch.Tensor,
+    output_size_h: int,
+    output_size_w: int,
+    sampling_ratio: int,
+    aligned: bool,
+) -> torch.Tensor:
+    K = rois.shape[0]
+    turing_rois = []
+    for i in range(K):
+        x1 = rois[i][1].item()
+        y1 = rois[i][2].item()
+        x2 = rois[i][3].item()
+        y2 = rois[i][4].item()
+        topLeftXY = (x1, y1)
+        bottomRightXY = (x2, y2)
+        turing_roi = convertBoxPosToTuringConfig(
+            topLeftXY,
+            bottomRightXY,
+            K,
+            output_size_h,
+            output_size_w,
+            sampling_ratio,
+            aligned,
+        )
+        turing_rois.append(torch.frombuffer(turing_roi, dtype=torch.uint8))
+
+    out = torch.stack(turing_rois)
+    return out
+
+
+@impl(m, "rms_norm")
+def rms_norm(
+    X: torch.Tensor,
+    normalized_shape: tuple[int],
+    W: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
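
The net effect of the ref_implementations.py changes above is that quantize_per_tensor (and the ops that call it) now receives the real quantization scale and inverts it internally, while dequantize_per_tensor now expects dtype to describe the input and always returns float32. A standalone sketch of the updated round-trip semantics, using plain torch helpers (quantize_ref / dequantize_ref are illustrative names, not the registered cadence ops):

import torch


def quantize_ref(
    x: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Mirrors the updated reference op: the caller passes the real scale,
    # and the inverse is computed internally before rounding and clamping.
    inv_scale = 1.0 / scale
    q = torch.round(x * inv_scale + zero_point).to(dtype)
    return torch.clamp(q, quant_min, quant_max)


def dequantize_ref(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Mirrors the updated reference op: widen, subtract the zero point,
    # rescale, and return float32.
    return ((q.to(torch.int32) - zero_point) * scale).to(torch.float32)


x = torch.tensor([0.0, 0.5, 1.0])
scale, zero_point = 1.0 / 127, 0      # roughly maps [-1.0, 1.0] onto int8
q = quantize_ref(x, scale, zero_point, -128, 127, torch.int8)
print(q)                               # 0, 64, 127 as int8
print(dequantize_ref(q, scale, zero_point))  # approximately recovers the inputs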

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 3 additions & 4 deletions
@@ -36,12 +36,11 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        inv_scale = 1.0 / scale
-        zero_point = round(-f_min * inv_scale) + q_min
+        zero_point = round(-f_min * 1 / scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)

         output = torch.ops.cadence.quantize_per_tensor(
-            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, scale, zero_point, q_min, q_max, target_dtype
         )

         self.assertEqual(
@@ -85,7 +84,7 @@ def test_dequantize_per_tensor(
         expected_output = torch.tensor([expected_value], dtype=torch.float32)

         output = torch.ops.cadence.dequantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+            input_tensor, scale, zero_point, q_min, q_max, input_tensor.dtype
         )

         self.assertEqual(
