Skip to content

Commit 45e3bd4

Browse files
committed
feat(aten::__derive_index): Implement derive index evaluator
Fixes: #834 Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 679ea21 commit 45e3bd4

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,16 @@ auto aten_registrations TORCHTRT_UNUSED =
806806
return 0;
807807
}
808808
},
809-
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})});
809+
EvalOptions().validSchemas({"aten::__range_length(int lo, int hi, int step) -> int"})})
810+
.evaluator(
811+
{c10::Symbol::fromQualString("aten::__derive_index"),
812+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
813+
auto idx = args.at(n->input(0)).unwrapToInt();
814+
auto start = args.at(n->input(1)).unwrapToInt();
815+
auto step = args.at(n->input(2)).unwrapToInt();
816+
return start + idx * step;
817+
}});
818+
810819
} // namespace
811820
} // namespace evaluators
812821
} // namespace conversion

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,3 +797,21 @@ TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
797797

798798
ASSERT_TRUE(jit_results[0] == trt_results[0]);
799799
}
800+
801+
TEST(Evaluators, DeriveIndexEvaluatesCorrectly) {
802+
const auto graph = R"IR(
803+
graph():
804+
%1 : int = prim::Constant[value=9]()
805+
%2 : int = prim::Constant[value=4]()
806+
%3 : int = prim::Constant[value=2]()
807+
%4 : int = aten::__derive_index(%1, %2, %3)
808+
return (%4))IR";
809+
810+
auto g = std::make_shared<torch::jit::Graph>();
811+
torch::jit::parseIR(graph, g.get());
812+
813+
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
814+
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
815+
816+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
817+
}

0 commit comments

Comments
 (0)