Skip to content

Commit 74f4a26

Browse files
committed
Remove tensorlist output support changes
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 33bedc8 commit 74f4a26

File tree

2 files changed

+10
-25
lines changed

2 files changed

+10
-25
lines changed

core/conversion/conversion.cpp

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -180,22 +180,7 @@ void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outp
180180
if (it == ctx->value_tensor_map.end()) {
181181
if (ctx->evaluated_value_map.find(out) != ctx->evaluated_value_map.end()) {
182182
auto out_ivalue = ctx->evaluated_value_map[out];
183-
if (out_ivalue.isList()) {
184-
auto output_list = out_ivalue.toList();
185-
LOG_DEBUG("One of the outputs is a TensorList. output_list size: " << output_list.size());
186-
187-
for (int i = 0; i < output_list.size(); i++) {
188-
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
189-
auto output_container = output_list.get(i).toCustomClass<TensorContainer>();
190-
nvinfer1::ITensor* out_tensor = output_container.get()->tensor();
191-
out_tensor->setName(name.c_str());
192-
ctx->net->markOutput(*out_tensor);
193-
LOG_INFO(
194-
ctx->logger,
195-
"Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
196-
ctx->num_outputs += 1;
197-
}
198-
} else if (out_ivalue.isCustomClass()) {
183+
if (out_ivalue.isCustomClass()) {
199184
std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
200185
auto output_container = out_ivalue.toCustomClass<TensorContainer>();
201186
nvinfer1::ITensor* out_tensor = output_container.get()->tensor();
@@ -374,8 +359,8 @@ void ConvertBlockToNetDef(
374359
auto eval_list = eval.value().toTuple();
375360
TRTORCH_CHECK(
376361
eval_list->elements().size() == n->outputs().size(),
377-
"Size of evaluated results: " << eval_list->elements().size() << " and node outputs size: " << n->outputs().size()
378-
<< " must match.");
362+
"Size of evaluated results: " << eval_list->elements().size()
363+
<< " and node outputs size: " << n->outputs().size() << " must match.");
379364
for (int i = 0; i < eval_list->elements().size(); i++) {
380365
auto eval_output = eval_list.get()->elements()[i];
381366
LOG_DEBUG(
@@ -384,8 +369,7 @@ void ConvertBlockToNetDef(
384369
ctx->AssociateValueAndIValue(n->output(i), eval_output);
385370
}
386371
} else {
387-
TRTORCH_THROW_ERROR(
388-
"Unsupported return type for evaluated node");
372+
TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
389373
}
390374
} else if (!eval.value().isTensor()) {
391375
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());

tests/util/evaluate_graph.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include "core/conversion/converters/converters.h"
66
#include "core/conversion/evaluators/evaluators.h"
77
#include "core/conversion/var/Var.h"
8-
#include "core/util/prelude.h"
98
#include "core/util/jit_util.h"
9+
#include "core/util/prelude.h"
1010

1111
namespace trtorch {
1212
namespace tests {
@@ -30,17 +30,18 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
3030
if (eval) {
3131
if (eval.value().isTuple()) {
3232
auto eval_list = eval.value().toTuple();
33-
for (int i = 0; i < eval_list->elements().size(); i++){
33+
for (int i = 0; i < eval_list->elements().size(); i++) {
3434
auto eval_output = eval_list.get()->elements()[i];
3535
LOG_DEBUG(
3636
ctx->logger,
37-
"Found the evaluated value(s) to be " << eval_output << " for node: " << trtorch::core::util::node_info(n));
37+
"Found the evaluated value(s) to be " << eval_output
38+
<< " for node: " << trtorch::core::util::node_info(n));
3839
ctx->AssociateValueAndIValue(n->output(i), eval_output);
3940
}
40-
} else if(!eval.value().isTensor()){
41+
} else if (!eval.value().isTensor()) {
4142
LOG_DEBUG("Found the value to be: " << eval.value());
4243
ctx->AssociateValueAndIValue(n->output(0), eval.value());
43-
}else {
44+
} else {
4445
LOG_DEBUG("Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
4546
ctx->AssociateValueAndIValue(n->output(0), eval.value());
4647
}

0 commit comments

Comments
 (0)