|
11 | 11 |
|
12 | 12 | # pyre-unsafe |
13 | 13 |
|
14 | | -import copy |
15 | 14 | import math |
16 | 15 | from operator import neg |
17 | 16 | from typing import cast, Dict, Iterable, Sequence, Set, Tuple |
|
36 | 35 | from executorch.backends.cadence.aot.utils import get_edge_overload_packet |
37 | 36 | from executorch.exir.dialects._ops import ops as exir_ops |
38 | 37 | from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket |
39 | | -from executorch.exir.dim_order_utils import get_memory_format |
40 | 38 | from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue |
41 | | -from executorch.exir.passes.dim_order_ops_registry import ( |
42 | | - DimOrderOpsMap, |
43 | | - MemoryFormatOpsMap, |
44 | | -) |
45 | 39 | from torch._subclasses import FakeTensor |
46 | 40 | from torch.fx.node import Argument |
47 | 41 |
|
@@ -1805,72 +1799,6 @@ def call_operator( |
1805 | 1799 | ) |
1806 | 1800 |
|
1807 | 1801 |
|
1808 | | -@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1809 | | -class ReplaceToDimOrderCopyWithToCopyPass(ExportPass): |
1810 | | - """ |
1811 | | - dim_order_ops::to_dim_order_copy is not supported, so this is an opt_level=0 pass. |
1812 | | - If the dim order is sequential, we don't need the extra work with strides and |
1813 | | - can just use to_copy. |
1814 | | - """ |
1815 | | - |
1816 | | - def call_operator( |
1817 | | - self, |
1818 | | - op, |
1819 | | - args: Tuple[Argument, ...], |
1820 | | - kwargs: Dict[str, Argument], |
1821 | | - meta: NodeMetadata, |
1822 | | - ) -> ProxyValue: |
1823 | | - if op not in DimOrderOpsMap: |
1824 | | - return super().call_operator(op, args, kwargs, meta) |
1825 | | - |
1826 | | - # new kwargs with dim_order, and no memory_format for the new op |
1827 | | - nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable |
1828 | | - |
1829 | | - ndim = None |
1830 | | - |
1831 | | - # can always get the shape, assuming rank is specialized |
1832 | | - |
1833 | | - # pyre-ignore[16]: `None` has no attribute `to_tensor` |
1834 | | - if isinstance(args[0], ProxyValue) and args[0].is_tensor(): |
1835 | | - # pyre-ignore[16]: `None` has no attribute `to_tensor` |
1836 | | - ndim = args[0].to_tensor().dim() |
1837 | | - elif isinstance(args[0], torch.Tensor): |
1838 | | - # pyre-ignore[16]: `None` has no attribute `dim` |
1839 | | - ndim = args[0].dim() |
1840 | | - elif isinstance(args[0], torch.fx.immutable_collections.immutable_list): |
1841 | | - # pyre-ignore[6]: Incompatible parameter type |
1842 | | - ndim = len(args[0]) |
1843 | | - else: |
1844 | | - assert 0, f"Expecting a Tensor or a ProxyValue but got {type(args[0])}" |
1845 | | - |
1846 | | - # get the "to" memory format for the EdgeOp |
1847 | | - contiguous_dim_order = list(range(ndim)) |
1848 | | - dim_order = nkwargs.pop("dim_order", None) |
1849 | | - |
1850 | | - # Cadence only supports contiguous memory format |
1851 | | - assert ( |
1852 | | - dim_order is None |
1853 | | - # pyre-ignore[6]: Incompatible parameter type |
1854 | | - or len(dim_order) == 0 |
1855 | | - or dim_order == contiguous_dim_order |
1856 | | - ), "Expected dim order in congituous or prevserve memory format, but got {}".format( |
1857 | | - dim_order |
1858 | | - ) |
1859 | | - |
1860 | | - # bring back memory format |
1861 | | - # pyre-ignore[6]: Incompatible parameter type |
1862 | | - nkwargs["memory_format"] = get_memory_format(dim_order) |
1863 | | - |
1864 | | - memory_format_op = MemoryFormatOpsMap[op] |
1865 | | - |
1866 | | - return super().call_operator( |
1867 | | - memory_format_op, |
1868 | | - args, |
1869 | | - nkwargs, |
1870 | | - meta, |
1871 | | - ) |
1872 | | - |
1873 | | - |
1874 | 1802 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1875 | 1803 | class ReplaceFullLikeWithFullPass(ExportPass): |
1876 | 1804 | """ |
@@ -2180,5 +2108,4 @@ class CadenceReplaceOpsInGraph: |
2180 | 2108 | ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass, |
2181 | 2109 | ReplaceAtenAvgPoolWithJarvisAvgPoolPass, |
2182 | 2110 | ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass, |
2183 | | - ReplaceToDimOrderCopyWithToCopyPass, |
2184 | 2111 | ] |
0 commit comments