From 0fe0696a0aa83c057f44999a28d5cd2e2a5e46eb Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Wed, 6 Nov 2024 08:24:25 -0800
Subject: [PATCH] Add trunc scalar prim_op (#6580)

Summary:
Add a primitive op kernel for scalar trunc (double -> int). This
corresponds to Python's math.trunc in the export graph and is used in
some torchvision transforms for size calculations. Note that trunc is
already supported in export.
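
As a rough illustration (this module and its names are hypothetical,
not part of this diff), the pattern arises when size math produces a
float during export and math.trunc brings it back to an int:

    import math
    import torch

    class Halve(torch.nn.Module):
        def forward(self, x):
            # Size math of the torchvision flavor; under dynamic shapes
            # this is a SymFloat rather than a plain Python float.
            scale = x.shape[-1] / 2.0
            # math.trunc lands in the export graph and is mapped to
            # executorch_prim::trunc.Scalar by the registry change below.
            return x[..., : math.trunc(scale)]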

Reviewed By: mcr229

Differential Revision: D65057149
---
 exir/passes/executorch_prim_ops_registry.py |  9 +++++++++
 kernels/prim_ops/register_prim_ops.cpp      | 16 ++++++++++++++++
 kernels/prim_ops/test/prim_ops_test.cpp     | 20 ++++++++++++++++++++
 3 files changed, 45 insertions(+)

diff --git a/exir/passes/executorch_prim_ops_registry.py b/exir/passes/executorch_prim_ops_registry.py
index 6362a471121..4af233aaa66 100644
--- a/exir/passes/executorch_prim_ops_registry.py
+++ b/exir/passes/executorch_prim_ops_registry.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import operator
 
 from typing import Dict, Set, Union
@@ -14,6 +15,8 @@
 from torch._ops import OpOverload
 from torch.library import Library
 
+# pyre-unsafe
+
 executorch_prims_lib = Library("executorch_prim", "DEF")
 
 
@@ -91,7 +94,13 @@ def neg(a: _SymScalar) -> _SymScalar:
     return -a  # pyre-ignore
 
 
+@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
+def trunc(a: _SymScalar) -> _SymScalar:
+    return math.trunc(a)  # pyre-ignore
+
+
 _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
+    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
     operator.sub: ops.backend.executorch_prim.sub.Scalar,
     operator.mul: ops.backend.executorch_prim.mul.Scalar,
     operator.add: ops.backend.executorch_prim.add.Scalar,
diff --git a/kernels/prim_ops/register_prim_ops.cpp b/kernels/prim_ops/register_prim_ops.cpp
index 7872b0d173f..5755ab8d66e 100644
--- a/kernels/prim_ops/register_prim_ops.cpp
+++ b/kernels/prim_ops/register_prim_ops.cpp
@@ -12,6 +12,8 @@
 #include <executorch/runtime/kernel/operator_registry.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <cmath>
+
 using torch::executor::function::et_copy_index;
 
 namespace torch {
@@ -301,6 +303,20 @@ static Kernel prim_ops[] = {
           }
         }),
 
+    // trunc.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::trunc.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
+          } else {
+            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
+          }
+        }),
+
     // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
     Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
     // executorch_prim::et_view.default(Tensor, int[]) -> Tensor
diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp
index 4b4b35a2324..3581a470da7 100644
--- a/kernels/prim_ops/test/prim_ops_test.cpp
+++ b/kernels/prim_ops/test/prim_ops_test.cpp
@@ -503,5 +503,25 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
       getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
 }
 
+TEST_F(RegisterPrimOpsTest, TestTrunc) {
+  std::array<double, 10> inputs = {
+      0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
+  std::array<int64_t, 10> expected = {0, 0, 0, 0, 1, 1, 0, -1, -1, 9};
+
+  for (size_t i = 0; i < inputs.size(); i++) {
+    EValue values[2];
+    values[0] = EValue(inputs[i]);
+    values[1] = EValue(0.0);
+
+    EValue* stack[2];
+    for (size_t j = 0; j < 2; j++) {
+      stack[j] = &values[j];
+    }
+
+    getOpsFn("executorch_prim::trunc.Scalar")(context, stack);
+    EXPECT_EQ(stack[1]->toInt(), expected[i]);
+  }
+}
+
 } // namespace executor
 } // namespace torch