Skip to content

Commit 33bedc8

Browse files
committed
Address review changes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 188b907 commit 33bedc8

File tree

4 files changed

+49
-15
lines changed

4 files changed

+49
-15
lines changed

core/conversion/conversion.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -370,19 +370,22 @@ void ConvertBlockToNetDef(
370370
auto eval = EvaluateNode(ctx, n);
371371
if (eval) {
372372
if (n->outputs().size() > 1) { // For ListUnpack scenario
373-
if (eval.value().isList()) {
374-
LOG_DEBUG(
375-
ctx->logger,
376-
"Found the evaluated value to be a list" << eval.value() << " for node: " << util::node_info(n));
377-
auto eval_list = eval.value().toList();
373+
if (eval.value().isTuple()) {
374+
auto eval_list = eval.value().toTuple();
378375
TRTORCH_CHECK(
379-
eval_list.size() == n->outputs().size(),
380-
"Size of evaluated results: " << eval_list.size() << " and node outputs size: " << n->outputs().size()
376+
eval_list->elements().size() == n->outputs().size(),
377+
"Size of evaluated results: " << eval_list->elements().size() << " and node outputs size: " << n->outputs().size()
381378
<< " must match.");
382-
for (int i = 0; i < eval_list.size(); i++) {
383-
auto eval_output = eval_list.get(i);
379+
for (int i = 0; i < eval_list->elements().size(); i++) {
380+
auto eval_output = eval_list.get()->elements()[i];
381+
LOG_DEBUG(
382+
ctx->logger,
383+
"Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info(n));
384384
ctx->AssociateValueAndIValue(n->output(i), eval_output);
385385
}
386+
} else {
387+
TRTORCH_THROW_ERROR(
388+
"Unsupported return type for evaluated node");
386389
}
387390
} else if (!eval.value().isTensor()) {
388391
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());

core/conversion/evaluators/prim.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ auto prim_registrations =
3636
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3737
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
3838
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
39-
return *std::move(outputs);
39+
auto outputVec = outputs->toList().vec();
40+
return std::move(c10::ivalue::Tuple::create(outputVec));
4041
}})
4142
.evaluator({torch::jit::prim::ListConstruct,
4243
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {

tests/core/conversion/evaluators/test_prim_evaluators.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,23 @@ TEST(Evaluators, PrimConstantEvaluatesCorrectly) {
1717
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
1818

1919
ASSERT_TRUE(jit_results[0] == trt_results[0]);
20-
}
20+
}
21+
22+
TEST(Evaluators, PrimListUnpackEvaluatesCorrectly) {
23+
const auto graph = R"IR(
24+
graph():
25+
%1 : int = prim::Constant[value=3]()
26+
%2 : int = prim::Constant[value=4]()
27+
%lc : int[] = prim::ListConstruct(%1, %2)
28+
%lu.1 : int, %lu.2 : int = prim::ListUnpack(%lc)
29+
return (%lu.1, %lu.2))IR";
30+
31+
auto g = std::make_shared<torch::jit::Graph>();
32+
torch::jit::parseIR(graph, &*g);
33+
34+
auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
35+
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
36+
37+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
38+
ASSERT_TRUE(jit_results[1] == trt_results[1]);
39+
}

tests/util/evaluate_graph.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "core/conversion/evaluators/evaluators.h"
77
#include "core/conversion/var/Var.h"
88
#include "core/util/prelude.h"
9+
#include "core/util/jit_util.h"
910

1011
namespace trtorch {
1112
namespace tests {
@@ -20,20 +21,30 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
2021
for (size_t i = 0; i < inputs.size(); i++) {
2122
ctx->AssociateValueAndIValue(b->inputs()[i], inputs[i]);
2223
}
23-
24+
LOG_DEBUG("Checking nodes");
2425
for (const auto n : b->nodes()) {
2526
TRTORCH_CHECK(
2627
core::conversion::evaluators::shouldEvalAtConversionTime(n),
2728
"Test graph contains non evaluatable nodes: " << *n);
2829
auto eval = core::conversion::EvaluateNode(ctx, n);
2930
if (eval) {
30-
if (!eval.value().isTensor()) {
31+
if (eval.value().isTuple()) {
32+
auto eval_list = eval.value().toTuple();
33+
for (int i = 0; i < eval_list->elements().size(); i++){
34+
auto eval_output = eval_list.get()->elements()[i];
35+
LOG_DEBUG(
36+
ctx->logger,
37+
"Found the evaluated value(s) to be " << eval_output << " for node: " << trtorch::core::util::node_info(n));
38+
ctx->AssociateValueAndIValue(n->output(i), eval_output);
39+
}
40+
} else if(!eval.value().isTensor()){
3141
LOG_DEBUG("Found the value to be: " << eval.value());
32-
} else {
42+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
43+
}else {
3344
LOG_DEBUG("Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
45+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
3446
}
3547
}
36-
ctx->AssociateValueAndIValue(n->output(0), eval.value());
3748
}
3849

3950
std::vector<torch::jit::IValue> outputs;

0 commit comments

Comments
 (0)