Skip to content

Commit c8dc7ad

Browse files
authored
Merge pull request #308 from NVIDIA/fix_num_to_tensor
Fix issue with num_to_tensor evaluator
2 parents 4d3ac4f + 4e6d51b commit c8dc7ad

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

core/conversion/evaluators/prim.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ auto prim_registrations =
3030
}})
3131
.evaluator({torch::jit::prim::NumToTensor,
3232
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
33-
return at::scalar_to_tensor(args.at(n->output(0)).IValue()->toScalar());
33+
return at::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
3434
}})
3535
.evaluator({torch::jit::prim::ListUnpack,
3636
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

tests/core/conversion/evaluators/test_prim_evaluators.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,19 @@ TEST(Evaluators, PrimListUnpackEvaluatesCorrectly) {
3737
ASSERT_TRUE(jit_results[0] == trt_results[0]);
3838
ASSERT_TRUE(jit_results[1] == trt_results[1]);
3939
}
40+
41+
TEST(Evaluators, NumToTensorEvaluatesCorrectly) {
42+
const auto graph = R"IR(
43+
graph():
44+
%1 : int = prim::Constant[value=3]()
45+
%lu.1 : Tensor = prim::NumToTensor(%1)
46+
return (%lu.1))IR";
47+
48+
auto g = std::make_shared<torch::jit::Graph>();
49+
torch::jit::parseIR(graph, &*g);
50+
51+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
52+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
53+
54+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
55+
}

0 commit comments

Comments
 (0)