Skip to content

Commit 096fd41

Browse files
add support for aten::reciprocal(int) (#1308)
1 parent b25738e commit 096fd41

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

core/conversion/converters/impl/unary.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@ auto abs_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern
4949
}
5050
}});
5151

52+
auto reciprocal_registration TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
53+
{"aten::reciprocal(Tensor self) -> Tensor", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
54+
auto in = args[0].ITensorOrFreeze(ctx);
55+
if (in->getType() == nvinfer1::DataType::kINT32) {
56+
// pytorch implicitly casts to float for aten::reciprocal(int)
57+
in = castITensor(ctx, in, nvinfer1::DataType::kFLOAT);
58+
}
59+
auto unary_layer = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kRECIP);
60+
TORCHTRT_CHECK(unary_layer, "Unable to create recip layer from node: " << *n);
61+
unary_layer->setName(util::node_info(n).c_str());
62+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], unary_layer->getOutput(0));
63+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
64+
return true;
65+
}});
66+
5267
#define convert(unary, trt_type) \
5368
auto unary##_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern( \
5469
{"aten::" #unary "(Tensor self) -> Tensor", \
@@ -74,7 +89,6 @@ convert(sinh, kSINH);
7489
convert(tan, kTAN);
7590
convert(atan, kATAN);
7691
convert(floor, kFLOOR);
77-
convert(reciprocal, kRECIP);
7892
convert(log, kLOG);
7993
convert(ceil, kCEIL);
8094
convert(sqrt, kSQRT);

tests/core/conversion/converters/test_unary.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,22 @@ TEST(Converters, ATenAbsIntConvertsCorrectly) {
3131
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
3232
}
3333

34+
TEST(Converters, ATenReciprocalIntConvertsCorrectly) {
35+
const auto graph = gen_test_graph("reciprocal");
36+
auto g = std::make_shared<torch::jit::Graph>();
37+
torch::jit::parseIR(graph, g.get());
38+
39+
auto in = at::tensor({-1, 1, -2, 2, -3, 3}, {at::kCUDA}).to(torch::kInt32);
40+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
41+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
42+
43+
in = at::clone(in);
44+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
45+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
46+
47+
ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0]));
48+
}
49+
3450
#define test_unary(unary, name) \
3551
TEST(Converters, ATen##name##ConvertsCorrectly) { \
3652
const auto graph = gen_test_graph(#unary); \

0 commit comments

Comments
 (0)