|  | 
|  | 1 | +# Copyright 2025 Arm Limited and/or its affiliates. | 
|  | 2 | +# | 
|  | 3 | +# This source code is licensed under the BSD-style license found in the | 
|  | 4 | +# LICENSE file in the root directory of this source tree. | 
|  | 5 | + | 
|  | 6 | +import torch | 
|  | 7 | +from executorch.exir.pass_base import ExportPass | 
|  | 8 | + | 
# Operator overloads that DecomposeCosineSimilarityPass rewrites.
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)


class DecomposeCosineSimilarityPass(ExportPass):
    """
    Decomposition of aten.cosine_similarity:

      dot    = sum(mul(x1, x2), dims, keepdim=False)
      norm   = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
      eps    = full_like( norm1, eps_scalar )
      n1c    = max(norm1, eps)
      n2c    = max(norm2, eps)
      denom  = mul(n1c, n2c)
      out    = div(dot, denom)

    Every other operator is forwarded unchanged to the base class.
    """

    def _l2_norm(self, x, dims, meta):
        """Emit pow(sum(x * x, dims, keepdim=False), 0.5) and return the result."""
        squared = super().call_operator(torch.ops.aten.mul.Tensor, (x, x), {}, meta)
        summed = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (squared, dims, False), {}, meta
        )
        return super().call_operator(
            torch.ops.aten.pow.Tensor_Scalar, (summed, 0.5), {}, meta
        )

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a cosine_similarity call with the primitive ops listed in the
        class docstring; forward any other operator to the base class.

        Args:
            op: operator overload being visited.
            args: positional arguments of the call, (x1, x2[, dim[, eps]]).
            kwargs: keyword arguments of the call; may carry "dim" and "eps".
            meta: node metadata, reused for every emitted operator.

        Returns:
            The result of the final div operator (or of the forwarded call).
        """
        if op not in torch_cosine_similarity:
            return super().call_operator(op, args, kwargs, meta)

        x1, x2 = args[0], args[1]
        # The schema is cosine_similarity(x1, x2, dim=1, eps=1e-8), so dim and
        # eps may arrive positionally as well as by keyword. Prefer the
        # positional value when present, then the keyword, then the default.
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 1)
        eps = args[3] if len(args) > 3 else kwargs.get("eps", 1e-8)
        dims = [dim] if isinstance(dim, int) else list(dim)

        # 1) dot = sum(x1 * x2)
        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
        dot = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
        )

        # 2) per-input L2 norms: pow(sum(x * x), 0.5)
        norm1 = self._l2_norm(x1, dims, meta)
        norm2 = self._l2_norm(x2, dims, meta)

        # 3) eps tensor - materialized with the shape of norm1 because TOSA
        #    does not broadcast a scalar for us
        eps_t = super().call_operator(
            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
        )

        # 4) clamp each norm from below to avoid division by zero
        n1c = super().call_operator(
            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
        )
        n2c = super().call_operator(
            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
        )

        # 5) denom and divide
        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
        return super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)
0 commit comments