Skip to content

Commit 3e69a64

Browse files
authored
Add custom cadence::linalg_svd operation with contiguous strides
Differential Revision: D80903940 Pull Request resolved: #13718
1 parent 15b9c7d commit 3e69a64

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_im2row_output_size,
1717
)
1818
from executorch.exir.scalar_type import ScalarType
19+
from torch._meta_registrations import _linalg_svd_meta
1920
from torch.library import Library, register_fake
2021

2122
lib = Library("cadence", "DEF")
@@ -250,6 +251,12 @@
250251
"int in_zero_point, bool channel_last=False) -> (Tensor out)"
251252
)
252253
lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
254+
lib.define(
255+
"linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)"
256+
)
257+
lib.define(
258+
"linalg_svd.out(Tensor A, bool full_matrices=False, bool compute_uv=True, str? driver=None, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)"
259+
)
253260
lib.define(
254261
"transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
255262
"int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -1576,6 +1583,26 @@ def linalg_vector_norm_meta(
15761583
return X.new_empty([], dtype=X.dtype)
15771584

15781585

1586+
@register_fake("cadence::linalg_svd")
1587+
def linalg_svd_meta(
1588+
A: torch.Tensor,
1589+
full_matrices: bool = False,
1590+
compute_uv: bool = True,
1591+
driver: Optional[str] = None,
1592+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1593+
# Based on the _linalg_svd meta implementation, but ensuring contiguous strides
1594+
1595+
# Get the shapes from the original meta function
1596+
U, S, Vh = _linalg_svd_meta(A, full_matrices, compute_uv, driver)
1597+
1598+
# Create new tensors with contiguous strides to fix the non-contiguous issue
1599+
U_contiguous = A.new_empty(U.shape, dtype=A.dtype).contiguous()
1600+
S_contiguous = A.new_empty(S.shape, dtype=A.dtype).contiguous()
1601+
Vh_contiguous = A.new_empty(Vh.shape, dtype=A.dtype).contiguous()
1602+
1603+
return U_contiguous, S_contiguous, Vh_contiguous
1604+
1605+
15791606
@register_fake("cadence::requantize")
15801607
def requantize_meta(
15811608
input: torch.Tensor,

0 commit comments

Comments
 (0)