Skip to content

Commit 44991cf

Browse files
authored
Merge branch 'main' into fix-macos-runners
2 parents df36f00 + e42c881 commit 44991cf

File tree

6 files changed

+189
-128
lines changed

6 files changed

+189
-128
lines changed

backends/cadence/aot/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,3 +604,17 @@ python_unittest(
604604
"//later:lib",
605605
],
606606
)
607+
608+
# Unit-test target covering the Cadence AOT reference implementations
# (tests/test_ref_implementations.py).
python_unittest(
    name = "test_ref_implementations",
    srcs = ["tests/test_ref_implementations.py"],
    supports_static_listing = False,
    typing = True,
    deps = [
        ":typing_stubs",
        "//caffe2:torch",
        "//executorch/backends/cadence/aot:ref_implementations",
    ],
)

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_im2row_output_size,
1717
)
1818
from executorch.exir.scalar_type import ScalarType
19+
from torch._meta_registrations import _linalg_svd_meta
1920
from torch.library import Library, register_fake
2021

2122
lib = Library("cadence", "DEF")
@@ -250,6 +251,12 @@
250251
"int in_zero_point, bool channel_last=False) -> (Tensor out)"
251252
)
252253
lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
254+
# Functional variant of the SVD operator: returns freshly allocated U, S, Vh.
lib.define(
    "linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, "
    "str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)"
)
# Out variant of the SVD operator: writes into caller-provided U, S, Vh.
lib.define(
    "linalg_svd.out(Tensor A, bool full_matrices=False, bool compute_uv=True, "
    "str? driver=None, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) "
    "-> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)"
)
253260
lib.define(
254261
"transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
255262
"int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -1576,6 +1583,26 @@ def linalg_vector_norm_meta(
15761583
return X.new_empty([], dtype=X.dtype)
15771584

15781585

1586+
@register_fake("cadence::linalg_svd")
def linalg_svd_meta(
    A: torch.Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) kernel for ``cadence::linalg_svd``.

    Output shapes and dtypes are taken from PyTorch's own ``_linalg_svd``
    meta function; the outputs are then re-allocated so that they carry
    default (contiguous) strides.

    Args:
        A: Input tensor to decompose.
        full_matrices: Forwarded unchanged to ``_linalg_svd_meta``.
        compute_uv: Forwarded unchanged to ``_linalg_svd_meta``.
        driver: Forwarded unchanged to ``_linalg_svd_meta``.

    Returns:
        A ``(U, S, Vh)`` tuple of uninitialized tensors with the shapes and
        dtypes reported by ``_linalg_svd_meta`` and contiguous strides.
    """
    # Let the core meta function do all shape/dtype inference and argument
    # validation so this kernel cannot drift from the upstream operator.
    U, S, Vh = _linalg_svd_meta(A, full_matrices, compute_uv, driver)

    # The upstream results may carry non-contiguous strides, so allocate fresh
    # tensors of the same shape. `new_empty` already yields default contiguous
    # strides, so no extra `.contiguous()` call is required. The dtype is taken
    # from each upstream output rather than from `A`, so that S keeps the
    # (real-valued) dtype chosen by the core meta function.
    U_contiguous = A.new_empty(U.shape, dtype=U.dtype)
    S_contiguous = A.new_empty(S.shape, dtype=S.dtype)
    Vh_contiguous = A.new_empty(Vh.shape, dtype=Vh.dtype)

    return U_contiguous, S_contiguous, Vh_contiguous
1604+
1605+
15791606
@register_fake("cadence::requantize")
15801607
def requantize_meta(
15811608
input: torch.Tensor,

backends/cadence/aot/ref_implementations.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,42 @@
2020
}
2121

2222

23+
@impl(m, "quantize_per_tensor")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Quantizes a floating-point tensor to an integral tensor.

    Computes ``round(input / scale + zero_point)``, saturates the result to
    ``[quant_min, quant_max]`` and casts it to ``dtype``.

    Args:
        - input (Tensor): input tensor
        - scale (float): Quantization scale. Derived from the ratio
            between the min/max of the floating-point tensor and the
            min/max of the quantized range.
        - zero_point (int): The point which represents 0 in the quantized
            range. For example, consider the floating point range [-1., 2.] and
            quantized integer range [-7, 7]. In this case, 0 is 1/3 of way from
            -1. to 2. So, the point that represents 0 in the quantized range should
            be 1/3 of the way from [-7, 7]. This ends up being -2 in the integer space.
        - quant_min (int): The smallest value in the quantized domain. Results
            below it are clamped to this value.
        - quant_max (int): The largest value in the quantized domain. Results
            above it are clamped to this value.
        - dtype (torch.dtype): The type of the output tensor

    Returns:
        Tensor of type ``dtype`` with the same shape as ``input``.

    Raises:
        ValueError: If ``dtype`` is not one of int8, int16 or int32.
    """
    supported_quant_types = [torch.int8, torch.int16, torch.int32]
    if dtype not in supported_quant_types:
        raise ValueError(
            f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
        )

    quantized = torch.round(input / scale + zero_point)
    # Saturate while still in floating point: casting an out-of-range float to
    # a narrow integer type would otherwise wrap around or be undefined
    # instead of pinning to the quantized range.
    return torch.clamp(quantized, quant_min, quant_max).to(dtype)
57+
58+
2359
@impl(m, "requantize")
2460
def requantize(
2561
input: torch.Tensor,
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import torch
12+
13+
from executorch.backends.cadence.aot.ref_implementations import quantize_per_tensor
14+
from executorch.backends.cadence.aot.typing_stubs import expand
15+
16+
17+
class TestRefImplementations(unittest.TestCase):
    """Checks the Cadence reference operator implementations against hand-computed values."""

    @expand(
        [
            ("basic_int8", 0.42, -1.0, 2.0, -7, 7, torch.int8, 0),
            ("basic_int16", 0.42, -1.0, 5.0, -6, 7, torch.int16, -3),
        ]
    )
    def test_quantize_per_tensor(
        self,
        name: str,
        input_value: float,
        f_min: float,
        f_max: float,
        q_min: int,
        q_max: int,
        target_dtype: torch.dtype,
        expected_value: int,
    ) -> None:
        """Quantize a single scalar and compare dtype and value with the expectation.

        The scale and zero point are derived from the float range
        ``[f_min, f_max]`` and the quantized range ``[q_min, q_max]``.
        """
        # Affine quantization parameters mapping the float range onto the
        # integer range.
        float_span = f_max - f_min
        quant_span = q_max - q_min
        scale = float_span / quant_span
        zero_point = q_min + round(-f_min / scale)

        expected_output = torch.tensor([expected_value], dtype=target_dtype)
        output = quantize_per_tensor(
            torch.tensor([input_value]),
            scale,
            zero_point,
            q_min,
            q_max,
            target_dtype,
        )

        dtype_message = f"Dtype mismatch in {name}"
        self.assertEqual(output.dtype, expected_output.dtype, dtype_message)

        values_message = (
            f"Values don't match in {name}: got {output}, expected {expected_output}"
        )
        self.assertTrue(torch.equal(output, expected_output), values_message)

0 commit comments

Comments
 (0)