Skip to content

Commit fab90d1

Browse files
committed
Add trunc scalar prim_op (pytorch#6580)
Summary: Add a primitive op kernel for scalar trunc (double -> int). This corresponds to python math.trunc in the export graph and is used in some torchvision transforms for size calculations. Note that trunc is already supported in export. Test Plan: Added a test for trunc in prim ops tests. Also validated end to end with MSGR ODNC model. Differential Revision: D65057149 Pulled By: GregoryComer
1 parent 17ad8d3 commit fab90d1

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

exir/passes/executorch_prim_ops_registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
import operator
89
from typing import Dict, Set, Union
910

@@ -91,7 +92,13 @@ def neg(a: _SymScalar) -> _SymScalar:
9192
return -a # pyre-ignore
9293

9394

95+
@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
def trunc(a: _SymScalar) -> _SymScalar:
    """Truncate a scalar toward zero, mirroring Python's ``math.trunc``.

    Corresponds to ``math.trunc`` in the export graph (double -> int);
    used e.g. by torchvision transforms for size calculations.
    """
    return math.trunc(a)  # pyre-ignore
98+
99+
94100
_PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
101+
math.trunc: ops.backend.executorch_prim.trunc.Scalar,
95102
operator.sub: ops.backend.executorch_prim.sub.Scalar,
96103
operator.mul: ops.backend.executorch_prim.mul.Scalar,
97104
operator.add: ops.backend.executorch_prim.add.Scalar,

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <executorch/runtime/kernel/kernel_includes.h>
1313
#include <executorch/runtime/kernel/operator_registry.h>
1414

15+
#include <cmath>
16+
1517
using torch::executor::function::et_copy_index;
1618

1719
namespace torch {
@@ -301,6 +303,20 @@ static Kernel prim_ops[] = {
301303
}
302304
}),
303305

306+
// trunc.Scalar(Scalar a) -> Scalar
307+
Kernel(
308+
"executorch_prim::trunc.Scalar",
309+
[](KernelRuntimeContext& context, EValue** stack) {
310+
(void)context;
311+
EValue& a = *stack[0];
312+
EValue& out = *stack[1];
313+
if (a.isDouble()) {
314+
out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
315+
} else {
316+
ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
317+
}
318+
}),
319+
304320
// executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
305321
Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
306322
// executorch_prim::et_view.default(Tensor, int[]) -> Tensor

kernels/prim_ops/test/prim_ops_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,5 +503,25 @@ TEST_F(RegisterPrimOpsTest, TestETViewEmpty) {
503503
getOpsFn("executorch_prim::et_view.default")(context, bad_stack), "");
504504
}
505505

506+
// Verifies executorch_prim::trunc.Scalar truncates doubles toward zero
// (positive and negative inputs) and writes an int64 to the out slot.
TEST_F(RegisterPrimOpsTest, TestTrunc) {
  std::array<double, 10> inputs = {
      0.0, 0.25, 0.5, 0.75, 1.0, 1.75, -0.5, -1.0, -1.5, 9.999999};
  std::array<int64_t, 10> expected = {0, 0, 0, 0, 1, 1, 0, -1, -1, 9};

  // size_t index avoids the signed/unsigned comparison against
  // inputs.size() that `auto i = 0` (int) would introduce.
  for (size_t i = 0; i < inputs.size(); i++) {
    EValue values[2];
    values[0] = EValue(inputs[i]);
    values[1] = EValue(0.0); // out slot; overwritten by the kernel

    // Stack layout: [input, output]. Direct aggregate init replaces
    // the original inner loop, which shadowed the outer index `i`.
    EValue* stack[2] = {&values[0], &values[1]};

    getOpsFn("executorch_prim::trunc.Scalar")(context, stack);
    EXPECT_EQ(stack[1]->toInt(), expected[i]);
  }
}
525+
506526
} // namespace executor
507527
} // namespace torch

0 commit comments

Comments
 (0)