@@ -11,6 +11,7 @@
 
 # pyre-unsafe
 
+import copy
 import math
 from operator import neg
 from typing import cast, Dict, Iterable, Sequence, Set, Tuple
@@ -35,7 +36,12 @@
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
+from executorch.exir.passes.dim_order_ops_registry import (
+    DimOrderOpsMap,
+    MemoryFormatOpsMap,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.node import Argument
 
@@ -1799,6 +1805,72 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceToDimOrderCopyWithToCopyPass(ExportPass):
+    """
+    dim_order_ops::to_dim_order_copy is not supported by the Cadence backend, so
+    this pass must run at opt_level=0. If the dim order is sequential (i.e.
+    contiguous), the extra stride bookkeeping is unnecessary and the op can be
+    lowered to to_copy.
+    """
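+
+    # A sketch of the intended rewrite (argument values are illustrative; the op
+    # names follow the edge dialect used in this file):
+    #   dim_order_ops._to_dim_order_copy(x, dim_order=[0, 1, 2, 3])
+    #   -> aten._to_copy(x, memory_format=torch.contiguous_format)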
+
+    def call_operator(
+        self,
+        op,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        meta: NodeMetadata,
+    ) -> ProxyValue:
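+        # Only rewrite ops registered as dim-order ops; anything else passes
+        # through unchanged.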
+        if op not in DimOrderOpsMap:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Build new kwargs: drop dim_order and add memory_format for the new op.
+        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
+
+        ndim = None
+
+        # We can always get the rank, assuming it has been specialized.
+
+        # pyre-ignore[16]: `None` has no attribute `to_tensor`
+        if isinstance(args[0], ProxyValue) and args[0].is_tensor():
+            # pyre-ignore[16]: `None` has no attribute `to_tensor`
+            ndim = args[0].to_tensor().dim()
+        elif isinstance(args[0], torch.Tensor):
+            # pyre-ignore[16]: `None` has no attribute `dim`
+            ndim = args[0].dim()
+        elif isinstance(args[0], torch.fx.immutable_collections.immutable_list):
+            # pyre-ignore[6]: Incompatible parameter type
+            ndim = len(args[0])
+        else:
+            raise AssertionError(
+                f"Expected a Tensor or a ProxyValue but got {type(args[0])}"
+            )
+
+        # Pull out the requested dim_order; the memory format for the new op is
+        # derived from it below.
+        contiguous_dim_order = list(range(ndim))
+        dim_order = nkwargs.pop("dim_order", None)
+
+        # Cadence only supports the contiguous memory format.
+        assert (
+            dim_order is None
+            # pyre-ignore[6]: Incompatible parameter type
+            or len(dim_order) == 0
+            or dim_order == contiguous_dim_order
+        ), "Expected dim order in contiguous or preserve memory format, but got {}".format(
+            dim_order
+        )
+
+        # Derive the memory_format kwarg from the (validated) dim order.
+        # pyre-ignore[6]: Incompatible parameter type
+        nkwargs["memory_format"] = get_memory_format(dim_order)
+
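+        # Look up the memory-format counterpart of this dim-order op
+        # (e.g. _to_dim_order_copy -> _to_copy).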
+        memory_format_op = MemoryFormatOpsMap[op]
+
+        return super().call_operator(
+            memory_format_op,
+            args,
+            nkwargs,
+            meta,
+        )
+
+
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplaceFullLikeWithFullPass(ExportPass):
     """
@@ -2108,4 +2180,5 @@ class CadenceReplaceOpsInGraph:
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
+        ReplaceToDimOrderCopyWithToCopyPass,
     ]