|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | 2 | # All rights reserved. |
| 3 | +# Copyright 2025 Arm Limited and/or its affiliates. |
3 | 4 | # |
4 | 5 | # This source code is licensed under the BSD-style license found in the |
5 | 6 | # LICENSE file in the root directory of this source tree. |
|
37 | 38 | ) |
38 | 39 | from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass |
39 | 40 | from executorch.backends.cadence.aot.utils import get_edge_overload_packet |
| 41 | +from executorch.backends.transforms.replace_scalar_with_tensor import ( |
| 42 | + ReplaceScalarWithTensorArgPass, |
| 43 | +) |
40 | 44 | from executorch.exir.dialects._ops import ops as exir_ops |
41 | 45 | from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket |
42 | 46 | from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue |
@@ -1713,65 +1717,9 @@ def call_operator(self, op, args, kwargs, meta): |
1713 | 1717 | ) |
1714 | 1718 |
|
1715 | 1719 |
|
1716 | | -@register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1717 | | -class ReplaceScalarWithTensorArgPass(ExportPass): |
1718 | | - """ |
1719 | | - For binary ops like add.Scalar, sub.Scalar mul.Scalar, and div.Scalar, |
1720 | | - replace the scalar arg with Tensor arg. |
1721 | | - """ |
1722 | | - |
1723 | | - scalar_to_tensor_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { |
1724 | | - exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, |
1725 | | - exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, |
1726 | | - exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, |
1727 | | - exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, |
1728 | | - } |
1729 | | - |
1730 | | - def get_replacement(self, op, args, kwargs, meta): |
1731 | | - return super().call_operator( |
1732 | | - # Replace with .Tensor variant. |
1733 | | - op=self.scalar_to_tensor_ops[op], |
1734 | | - args=( |
1735 | | - # Tensor arg. |
1736 | | - args[0], |
1737 | | - # Scalar arg - replace with aten.full tensor. |
1738 | | - super().call_operator( |
1739 | | - exir_ops.edge.aten.full.default, |
1740 | | - args=( |
1741 | | - (1,), |
1742 | | - args[1], |
1743 | | - ), |
1744 | | - kwargs={"dtype": args[0].to_tensor().dtype}, |
1745 | | - meta=meta, |
1746 | | - ), |
1747 | | - # Other args. |
1748 | | - *args[2:], |
1749 | | - ), |
1750 | | - kwargs=kwargs, |
1751 | | - meta=meta, |
1752 | | - ) |
1753 | | - |
1754 | | - def call_operator(self, op, args, kwargs, meta): |
1755 | | - if op not in self.scalar_to_tensor_ops: |
1756 | | - return super().call_operator(op, args, kwargs, meta) |
1757 | | - |
1758 | | - # There must be exactly 2 args (3 for add and sub containing alpha) |
1759 | | - assert len(args) == 2 or len(args) == 3 |
1760 | | - |
1761 | | - # If there are two args, just replace the op. |
1762 | | - if len(args) == 2: |
1763 | | - return self.get_replacement(op, args, kwargs, meta) |
1764 | | - |
1765 | | - # In case the op has three args, it must be scalar add/sub op. |
1766 | | - if ( |
1767 | | - op not in {exir_ops.edge.aten.add.Scalar, exir_ops.edge.aten.sub.Scalar} |
1768 | | - or "alpha" in kwargs |
1769 | | - ): |
1770 | | - return super().call_operator(op, args, kwargs, meta) |
1771 | | - |
1772 | | - return self.get_replacement(op, args, kwargs, meta) |
1773 | | - |
1774 | | - |
| 1720 | +@register_cadence_pass(CadencePassAttribute(opt_level=0))( |
| 1721 | + ReplaceScalarWithTensorArgPass() |
| 1722 | +) |
1775 | 1723 | @register_cadence_pass(CadencePassAttribute(opt_level=0)) |
1776 | 1724 | class ReplaceScalarTensorWithFullPass(ExportPass): |
1777 | 1725 | """ |
|
0 commit comments