Skip to content

Commit 3990787

Browse files
committed
fix bugs in split converter
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 9ff9c22 commit 3990787

File tree

5 files changed: +94 additions, −19 deletions

core/conversion/converters/impl/select.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
1818
auto in = args[0].ITensor();
1919
auto axis = args[2].unwrapToInt();
2020
auto inDimSize = in->getDimensions().d[axis];
21-
auto numOutputs = 1;
21+
auto numOutputs = 1, numRemainder = 0;
2222
std::vector<int64_t> sizes;
2323

2424
if (split_list) {
@@ -27,10 +27,13 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
2727
} else {
2828
auto split_size = args[1].unwrapToInt();
2929
numOutputs = inDimSize / split_size;
30-
if (numOutputs == 1) {
30+
numRemainder = inDimSize % split_size;
31+
for (int i = 0; i < numOutputs; i++) {
3132
sizes.push_back(split_size);
32-
} else {
33-
sizes = std::vector<int64_t>(numOutputs, 1);
33+
}
34+
if (numRemainder) {
35+
numOutputs += 1;
36+
sizes.push_back(numRemainder);
3437
}
3538
}
3639

core/conversion/var/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ cc_library(
1919
deps = [
2020
"@tensorrt//:nvinfer",
2121
"//core/util:prelude",
22-
"//core/conversion/converters:weights"
22+
"//core/conversion/converters:weights",
23+
"//core/conversion/tensorcontainer:tensorcontainer"
2324
] + select({
2425
":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
2526
"//conditions:default": ["@libtorch//:libtorch"],

core/conversion/var/Var.cpp

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,24 +90,37 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9090
LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
9191
}
9292
TRTORCH_CHECK(
93-
isITensor() || (isIValue() && ptr_.ivalue->isTensor()),
93+
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
9494
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
9595

9696
nvinfer1::ITensor* out;
97-
97+
auto weights = converters::Weights();
9898
if (isIValue()) {
99-
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
100-
101-
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
102-
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
103-
104-
out = const_layer->getOutput(0);
105-
106-
std::ostringstream tensor_id;
107-
tensor_id << reinterpret_cast<int*>(out);
108-
109-
LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
110-
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
99+
if (ptr_.ivalue->isTensor()) {
100+
auto tensor = ptr_.ivalue->toTensor();
101+
if (tensor.scalar_type() == at::kLong) {
102+
weights = converters::Weights(ctx, tensor.toType(at::kInt));
103+
} else if (tensor.scalar_type() == at::kDouble) {
104+
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
105+
} else {
106+
weights = converters::Weights(ctx, tensor);
107+
}
108+
109+
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
110+
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
111+
112+
out = const_layer->getOutput(0);
113+
114+
std::ostringstream tensor_id;
115+
tensor_id << reinterpret_cast<int*>(out);
116+
117+
LOG_DEBUG(ctx->logger, "Freezing tensor " << tensor_id.str() << " as an IConstantLayer");
118+
const_layer->setName(("[Freeze Tensor " + tensor_id.str() + " ]").c_str());
119+
} else {
120+
// Split converter generates c10::IValue which hold TensorContainer.
121+
auto output_container = ptr_.ivalue->toCustomClass<TensorContainer>();
122+
out = output_container.get()->tensor();
123+
}
111124
} else {
112125
out = ptr_.tensor;
113126
}

core/conversion/var/Var.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "core/conversion/conversionctx/ConversionCtx.h"
77
#include "core/conversion/converters/Weights.h"
8+
#include "core/conversion/tensorcontainer/TensorContainer.h"
89
#include "core/util/prelude.h"
910
#include "torch/csrc/jit/ir/ir.h"
1011

tests/core/conversion/converters/test_select.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,60 @@ TEST(Converters, ATenSplitFixedConvertsCorrectly) {
288288
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
289289
}
290290
}
291+
292+
// Verifies aten::split with a fixed split_size that does not evenly divide the
// dimension: size 5 split by 2 along dim 1 should yield chunks of [2, 2, 1],
// so ListUnpack expects exactly three outputs.
TEST(Converters, ATenSplitFixedHasRemainderConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
          %2 : int = prim::Constant[value=2]()
          %2.1 : int = prim::Constant[value=1]()
          %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
          %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
          return (%4, %5, %6))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_graph);

  // Integer input so JIT and TRT paths produce bit-comparable values.
  auto input = at::randint(1, 10, {1, 5, 4, 4}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {});

  auto jit_input = at::clone(input);
  auto jit_results = trtorch::tests::util::RunGraph(parsed_graph, params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_graph, params, {trt_input});

  // Compare every split chunk, reshaping TRT output to the JIT shape first.
  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt_out = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_out, 2e-6));
  }
}
319+
320+
// Verifies that the outputs of aten::split can feed a downstream op
// (aten::add of the two chunks): this exercises the TensorContainer /
// ITensorOrFreeze path for IValues produced by the split converter.
TEST(Converters, ATenSplitAndAddConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
          %2 : int = prim::Constant[value=2]()
          %2.1 : int = prim::Constant[value=1]()
          %3 : Tensor[] = aten::split(%argument_1.1, %2, %2.1)
          %4 : Tensor, %5 : Tensor = prim::ListUnpack(%3)
          %6 : Tensor = aten::add(%4, %5, %2.1)
          return (%6))IR";

  auto parsed_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed_graph);

  // Dim 1 has size 4, so split_size=2 yields exactly the two chunks unpacked above.
  auto input = at::randint(1, 10, {1, 4, 4, 4}, {at::kCUDA});
  auto params = trtorch::core::conversion::get_named_params(parsed_graph->inputs(), {});

  auto jit_input = at::clone(input);
  auto jit_results = trtorch::tests::util::RunGraph(parsed_graph, params, {jit_input});

  auto trt_input = at::clone(input);
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed_graph, params, {trt_input});

  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt_out = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt_out, 2e-6));
  }
}

Comments (0)