Commit ed179c0

Ref implementations interface fixes
Differential Revision: D82566217
Pull Request resolved: #14357
1 parent 5348ea9

File tree: 3 files changed (+102, -52 lines)

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ runtime.python_library(
     deps = [
         "fbcode//caffe2:torch",
         "fbcode//executorch/exir:scalar_type",
+        "fbcode//executorch/kernels/quantized:custom_ops_generated_lib",
     ],
 )

backends/cadence/aot/ref_implementations.py

Lines changed: 52 additions & 36 deletions
@@ -6,16 +6,17 @@
 
 # pyre-strict
 
-
 from typing import Callable
 
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
 
-
 m = Library("cadence", "IMPL", "CompositeExplicitAutograd")
+torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib")
 
 qdtype_map: dict[ScalarType, torch.dtype] = {
     ScalarType.QINT8: torch.qint8,
@@ -38,7 +39,7 @@ def quantize_per_tensor(
 
     Args:
         - input_tensor (Tensor): input tensor
-        - scale (float): Inverse of quantization scale. Derived from the ratio
+        - scale (float): Quantization scale. Derived from the ratio
           between the min/max of the floating-point tensor and the
           min/max of the quantized range, and then inverted.
        - zero_point (int): The point which represents 0 in the quantized
@@ -64,10 +65,13 @@ def quantize_per_tensor(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )
 
-    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
-    return torch.max(
-        torch.min(quantized, torch.tensor(quant_max)),
-        torch.tensor(quant_min),
+    return torch.ops.quantized_decomposed.quantize_per_tensor(
+        input_tensor,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        dtype,
     )
 
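For intuition, and not part of the diff: the deleted inline code multiplied by scale, treating the argument as an inverse scale, while the quantized_decomposed reference op divides by it. A minimal pure-PyTorch sketch of what the delegated call computes, assuming the usual round(x / scale) + zero_point semantics followed by clamping:

    import torch

    def quantize_per_tensor_sketch(
        input_tensor: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # scale is now the true quantization scale, so divide by it;
        # the deleted code multiplied, i.e. it expected 1/scale.
        quantized = torch.round(input_tensor / scale) + zero_point
        return torch.clamp(quantized, quant_min, quant_max).to(dtype)

    x = torch.tensor([0.0, 0.5, 1.0])
    print(quantize_per_tensor_sketch(x, 1.0 / 255, 0, -128, 127, torch.int8))
    # tensor([0, 127, 127], dtype=torch.int8): 0.5 / (1/255) = 127.5 rounds
    # to 128 (half-to-even), then clamps to 127.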

@@ -97,7 +101,7 @@ def dequantize_per_tensor(
           is already provided.
         - quant_max (int): The largest value in the quantized domain. Unused since scale
           is already provided.
-        - dtype (torch.dtype): The type of the output tensor. Must be a floating point type.
+        - dtype (torch.dtype): The type of the input tensor.
     """
     supported_quant_types = [
         torch.int8,
@@ -108,23 +112,15 @@ def dequantize_per_tensor(
     ]
     if input_tensor.dtype not in supported_quant_types:
         raise ValueError(f"Input dtype must be one of {supported_quant_types}")
-    supported_dequant_types = [
-        torch.float,
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ]
-    if dtype not in supported_dequant_types:
-        raise ValueError(
-            f"Unsupported dtype to dequantize to. Supported dtypes must be one of {supported_dequant_types}"
-        )
-
-    # Needed to prevent underflow in cases where the zero_point is larger than
-    # the quantized value.
-    if not input_tensor.dtype.is_signed:
-        input_tensor = input_tensor.to(torch.int32)
-
-    return (input_tensor - zero_point).to(dtype) * scale
+    if input_tensor.dtype != dtype:
+        raise ValueError("Input dtype must match dtype")
+
+    # Use the reference implementation from torch quantized_decomposed library
+    # Unlike quantize_per_tensor, dequantize_per_tensor doesn't have a behavior
+    # difference, since there's no rounding algorithm (just arithmetic).
+    return torch.ops.quantized_decomposed.dequantize_per_tensor(
+        input_tensor, scale, zero_point, quant_min, quant_max, dtype
+    )
 
 
 @impl(m, "quantized_add.per_tensor")
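The delegated dequantize call above reduces to plain affine arithmetic. A small sketch of the equivalent computation, with an int32 upcast mirroring the underflow guard the deleted code applied to unsigned inputs:

    import torch

    def dequantize_per_tensor_sketch(
        input_tensor: torch.Tensor, scale: float, zero_point: int
    ) -> torch.Tensor:
        # Upcast before subtracting so unsigned inputs cannot underflow.
        return (input_tensor.to(torch.int32) - zero_point).to(torch.float32) * scale

    q = torch.tensor([0, 10], dtype=torch.uint8)
    print(dequantize_per_tensor_sketch(q, 0.1, 5))  # tensor([-0.5000,  0.5000])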
@@ -180,12 +176,10 @@ def quantized_add_per_tensor(
     dequant_X = X_scale * (X - X_zero_point)
     dequant_Y = Y_scale * (Y - Y_zero_point)
 
-    out_scale_inv = 1 / out_scale
-
     # q_min/q_max are unused args
     return quantize_per_tensor(
         dequant_X + dequant_Y,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
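With the inversion removed, quantize_per_tensor now receives the true output scale. A hand-computed example of the add path, using illustrative values rather than ones from the tests:

    import torch

    X, X_scale, X_zero_point = torch.tensor([10]), 0.1, 0
    Y, Y_scale, Y_zero_point = torch.tensor([5]), 0.2, 0
    out_scale, out_zero_point = 0.5, 0

    dequant_sum = X_scale * (X - X_zero_point) + Y_scale * (Y - Y_zero_point)  # 2.0
    # quantize_per_tensor divides by out_scale: round(2.0 / 0.5) + 0 = 4
    out = torch.clamp(torch.round(dequant_sum / out_scale) + out_zero_point, -128, 127)
    print(out.to(torch.int8))  # tensor([4], dtype=torch.int8)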
@@ -259,8 +253,7 @@ def quantized_linear_common(
         - out_zero_point (int): The quantized mapping of zero for the output
         - offset (Tensor): Unused
     """
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
-    out_scale_inv = 1 / out_scale
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
 
     N, K = weight.shape
 
@@ -281,7 +274,7 @@ def quantized_linear_common(
     )
     return quantize_per_tensor(
         out,
-        out_scale_inv,
+        out_scale,
         out_zero_point,
         torch.iinfo(dtype).min,
         torch.iinfo(dtype).max,
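The multiplier/shift pair is a fixed-point encoding: out_multiplier holds a fraction of 2**31 and out_shift a power of two. Because quantize_per_tensor now divides by the scale it is handed, the code passes the reciprocal of that product (the minus sign is unchanged from the old expression). A worked example with the multiplier that appears in the tests:

    out_multiplier = 1073741824                # 0.5 * 2**31
    out_shift = 0
    fraction = out_multiplier * (1 / (1 << 31)) * (2 ** out_shift)
    assert fraction == 0.5
    out_scale = 1.0 / -fraction               # -2.0; dividing by it multiplies by -0.5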
@@ -399,6 +392,17 @@ def quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor() -> torch.Tensor:
 def quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
+@impl(m, "fully_connected")
+def fully_connected(
+    input_tensor: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+) -> torch.Tensor:
+    if input_tensor.shape[0] != 1:
+        raise ValueError("Fully connected linear only supports batch size of 1")
+    return F.linear(input_tensor, weight, bias)
+
+
 @impl(m, "quantized_matmul")
 def quantized_matmul(
     X: torch.Tensor,
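A hypothetical call into the new fully_connected reference op, assuming the cadence op schema is registered when this module is imported; shapes follow F.linear, with weight laid out as [out_features, in_features] and a required batch of 1:

    import torch

    x = torch.randn(1, 3)   # batch size must be 1
    w = torch.randn(2, 3)   # [out_features, in_features]
    b = torch.zeros(2)
    y = torch.ops.cadence.fully_connected(x, w, b)
    print(y.shape)          # torch.Size([1, 2])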
@@ -538,15 +542,15 @@ def quantized_layer_norm_per_tensor(
     )
 
     float_input_tensor = dequantize_per_tensor(
-        input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
+        input_tensor, X_scale, X_zero_point, -128, 127, input_tensor.dtype
     )
     out = torch.nn.functional.layer_norm(
         float_input_tensor, normalized_shape, weight, bias, eps=eps
     )
 
     return quantize_per_tensor(
         out,
-        1 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -615,7 +619,7 @@ def quantized_conv_per_tensor(
 
     return quantize_per_tensor(
         float_out,
-        1.0 / output_scale,
+        output_scale,
         output_zero_point,
         torch.iinfo(input_tensor.dtype).min,
         torch.iinfo(input_tensor.dtype).max,
@@ -950,8 +954,10 @@ def quantized_relu_common(
     if X.dtype not in supported_dtypes:
         raise ValueError(f"X dtype must be one of {supported_dtypes}. Got {X.dtype}")
 
-    out_scale = -out_multiplier * (1 / (1 << 31)) * (2**out_shift)
-    dequantized_X = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X))
+    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2**out_shift))
+    dequantized_X = torch.where(
+        X > X_zero_point, X - X_zero_point, torch.zeros_like(X)
+    ).to(torch.float32)
     return quantize_per_tensor(
         dequantized_X,
         out_scale,
@@ -1076,3 +1082,13 @@ def requantize(
         out_quant_max,
         dtype,
     )
+
+
+@impl(m, "rms_norm")
+def rms_norm(
+    X: torch.Tensor,
+    normalized_shape: tuple[int],
+    W: torch.Tensor,
+    eps: float,
+) -> torch.Tensor:
+    return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
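A small self-check of what the new rms_norm reference computes, assuming nn.RMSNorm's default elementwise weight of ones, so that the external W supplies the gain:

    import torch

    def rms_norm_sketch(X: torch.Tensor, W: torch.Tensor, eps: float) -> torch.Tensor:
        # Normalize over the last dim: x / sqrt(mean(x^2) + eps), then apply the gain.
        rms = torch.sqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
        return W * (X / rms)

    X, W = torch.randn(2, 4), torch.randn(4)
    torch.testing.assert_close(
        rms_norm_sketch(X, W, 1e-6),
        W * torch.nn.RMSNorm([4], eps=1e-6)(X),
    )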

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 49 additions & 16 deletions
@@ -36,12 +36,11 @@ def test_quantize_per_tensor(
     ) -> None:
         input_tensor = torch.tensor([input_value])
         scale = (f_max - f_min) / (q_max - q_min)
-        inv_scale = 1.0 / scale
-        zero_point = round(-f_min * inv_scale) + q_min
+        zero_point = round(-f_min * 1 / scale) + q_min
         expected_output = torch.tensor([expected_value], dtype=target_dtype)
 
         output = torch.ops.cadence.quantize_per_tensor(
-            input_tensor, inv_scale, zero_point, q_min, q_max, target_dtype
+            input_tensor, scale, zero_point, q_min, q_max, target_dtype
        )
 
         self.assertEqual(
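The updated zero-point arithmetic in this test, worked through for one assumed range (floats in [-1, 1] mapped onto int8):

    f_min, f_max, q_min, q_max = -1.0, 1.0, -128, 127
    scale = (f_max - f_min) / (q_max - q_min)      # 2 / 255
    zero_point = round(-f_min * 1 / scale) + q_min
    assert zero_point == 0                         # round(127.5) = 128; 128 - 128 = 0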
@@ -85,7 +84,7 @@ def test_dequantize_per_tensor(
         expected_output = torch.tensor([expected_value], dtype=torch.float32)
 
         output = torch.ops.cadence.dequantize_per_tensor(
-            input_tensor, scale, zero_point, q_min, q_max, torch.float32
+            input_tensor, scale, zero_point, q_min, q_max, input_tensor.dtype
        )
 
         self.assertEqual(
@@ -175,7 +174,7 @@ def test_quantized_add(
                 ),  # out_multiplier (0.5 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
-                torch.tensor([[-2]], dtype=dtype),  # expected_output
+                torch.tensor([[0]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -200,14 +199,36 @@ def test_quantized_add(
                 ),  # out_multiplier (0.5 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
-                torch.tensor([[-10, -30]], dtype=dtype),  # expected_output
+                torch.tensor([[-2, -8]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
             )
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
+            )
+        ],
+        *[
+            (
+                torch.Size([1, 3]),  # src_shape: 1 sample, 3 input features
+                torch.Size(
+                    [2, 3]
+                ),  # weight_shape: 2 output features, 3 input features
+                0,  # in_zero_point
+                torch.tensor([0, 0, 0], dtype=dtype),  # weight_zero_point
+                torch.tensor(
+                    [1073741824], dtype=torch.int32
+                ),  # out_multiplier (0.5 * 2^31)
+                torch.tensor([0], dtype=torch.int64),  # out_shift
+                0,  # out_zero_point
+                torch.tensor([[0, 0]], dtype=dtype),  # expected_output
+                per_tensor,
+                False,
+                False,
+            )
+            for (per_tensor, dtype) in (
+                (False, torch.uint8),
                 (True, torch.uint8),
             )
         ],
@@ -226,7 +247,7 @@ def test_quantized_add(
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 0,  # out_zero_point
                 torch.tensor(
-                    [[[-2, -8, -14], [-6, -28, -50]]], dtype=dtype
+                    [[[0, -2, -4], [-2, -7, -12]]], dtype=dtype
                 ),  # expected_output
                 per_tensor,
                 False,
@@ -235,7 +256,6 @@ def test_quantized_add(
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
-                (True, torch.uint8),
             )
         ],
         # Test case 4: Non-zero zero points
@@ -252,15 +272,15 @@ def test_quantized_add(
                 ),  # out_multiplier (1.0 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 1,  # out_zero_point
-                torch.tensor([[-15, 25]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
             )
             for (per_tensor, dtype) in (
                 (False, torch.int8),
                 (True, torch.int8),
-                (True, torch.uint8),
+                # (True, torch.uint8),
             )
         ],
         # Test case 5: Non-uniform weight zero points
@@ -277,12 +297,12 @@ def test_quantized_add(
                 ),  # out_multiplier (1.0 * 2^31)
                 torch.tensor([0], dtype=torch.int64),  # out_shift
                 1,  # out_zero_point
-                torch.tensor([[-23, 17]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 1]], dtype=dtype),  # expected_output
                 False,
                 False,
                 False,
             )
-            for dtype in (torch.int8, torch.uint8)
+            for dtype in (torch.int8,)
         ],
         # Test case 6: Non-zero out_shift (shift=1)
         *[
@@ -300,7 +320,7 @@ def test_quantized_add(
                     [1], dtype=torch.int64
                 ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
-                torch.tensor([[-7, 13]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 2]], dtype=dtype),  # expected_output
                 per_tensor,
                 False,
                 False,
@@ -322,13 +342,13 @@ def test_quantized_add(
                     [1], dtype=torch.int64
                 ),  # out_shift (shift=1, doubles the scale)
                 1,  # out_zero_point
-                torch.tensor([[-7, 17]], dtype=dtype),  # expected_output
+                torch.tensor([[1, 2]], dtype=dtype),  # expected_output
                 per_tensor,
                 matmul,
                 transposed_matmul,
             )
             for (matmul, transposed_matmul) in ((True, False), (True, True))
-            for (per_tensor, dtype) in ((True, torch.int8), (True, torch.uint8))
+            for (per_tensor, dtype) in ((True, torch.int8),)
         ],
     ]
 )
@@ -1045,7 +1065,20 @@ def test_quantized_conv_per_tensor(
                     [4, 2, 0, -2], dtype=dtype
                 ),  # expected: relu(1,3,5,7) = (1,3,5,7) * (-1.0) + 5 = (4,2,0,-2)
             )
-            for dtype in [torch.int8, torch.uint8]
+            for dtype in [torch.int8]
+        ],
+        *[
+            (
+                "positive_with_shift_unsigned",
+                torch.tensor([2, 4, 6, 8], dtype=dtype),  # input
+                1,  # X_zero_point
+                5,  # out_zero_point
+                1073741824,  # out_multiplier (0.5 * 2^31)
+                1,  # out_shift (multiply by 2^1 = 2)
+                dtype,  # dtype
+                torch.tensor([4, 2, 0, 0], dtype=dtype),
+            )
+            for dtype in [torch.uint8]
         ],
         # Test case 4: Non-per-tensor
         *[
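Checking the new unsigned case by hand, as a sketch of the reference relu math above; uint8 clamping at zero is what separates the unsigned expectation from the signed [4, 2, 0, -2]:

    import torch

    X = torch.tensor([2, 4, 6, 8])
    X_zero_point, out_zero_point = 1, 5
    out_multiplier, out_shift = 1073741824, 1  # 0.5 * 2**31, shift of 1
    out_scale = 1.0 / (-out_multiplier * (1 / (1 << 31)) * (2 ** out_shift))  # -1.0
    relu = torch.where(X > X_zero_point, X - X_zero_point, torch.zeros_like(X)).to(
        torch.float32
    )                                                    # [1., 3., 5., 7.]
    q = torch.round(relu / out_scale) + out_zero_point   # [4., 2., 0., -2.]
    print(torch.clamp(q, 0, 255).to(torch.uint8))        # tensor([4, 2, 0, 0], dtype=torch.uint8)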
