|
16 | 16 | get_im2row_output_size,
|
17 | 17 | )
|
18 | 18 | from executorch.exir.scalar_type import ScalarType
|
| 19 | +from torch._meta_registrations import _linalg_svd_meta |
19 | 20 | from torch.library import Library, register_fake
|
20 | 21 |
|
21 | 22 | lib = Library("cadence", "DEF")
|
|
250 | 251 | "int in_zero_point, bool channel_last=False) -> (Tensor out)"
|
251 | 252 | )
|
252 | 253 | lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
|
# Register the cadence::linalg_svd op, mirroring torch.linalg.svd's signature
# (note: full_matrices defaults to False here, unlike torch.linalg.svd's True).
# Returns the (U, S, Vh) factorization triple.
lib.define(
    "linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)"
)
# Out-variant: writes the factors into caller-provided U, S, Vh buffers.
lib.define(
    "linalg_svd.out(Tensor A, bool full_matrices=False, bool compute_uv=True, str? driver=None, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)"
)
253 | 260 | lib.define(
|
254 | 261 | "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
|
255 | 262 | "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
|
@@ -1576,6 +1583,26 @@ def linalg_vector_norm_meta(
|
1576 | 1583 | return X.new_empty([], dtype=X.dtype)
|
1577 | 1584 |
|
1578 | 1585 |
|
@register_fake("cadence::linalg_svd")
def linalg_svd_meta(
    A: torch.Tensor,
    full_matrices: bool = False,
    compute_uv: bool = True,
    driver: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fake (meta) kernel for cadence::linalg_svd.

    Delegates shape inference to PyTorch's ``_linalg_svd`` meta
    implementation, then re-allocates each output so the reported
    strides are the default contiguous ones (the upstream meta kernel
    may report non-contiguous strides).

    Args:
        A: Input tensor to factorize.
        full_matrices: Whether full-sized U / Vh shapes are produced.
        compute_uv: Whether U and Vh are computed.
        driver: Optional backend driver name (shape-irrelevant here).

    Returns:
        Meta tensors (U, S, Vh) with contiguous default strides.
    """
    # Shapes and dtypes come from the reference meta function.
    U, S, Vh = _linalg_svd_meta(A, full_matrices, compute_uv, driver)

    # Use each output's own dtype rather than A.dtype: S is always
    # real-valued, so A.dtype would be wrong for complex inputs.
    # (For real A all three dtypes equal A.dtype, so this is
    # backward-compatible.)  new_empty already produces contiguous
    # tensors, so no extra .contiguous() call is needed.
    return (
        A.new_empty(U.shape, dtype=U.dtype),
        A.new_empty(S.shape, dtype=S.dtype),
        A.new_empty(Vh.shape, dtype=Vh.dtype),
    )
1579 | 1606 | @register_fake("cadence::requantize")
|
1580 | 1607 | def requantize_meta(
|
1581 | 1608 | input: torch.Tensor,
|
|
0 commit comments