Skip to content

Commit 10aaaf4

Browse files
committed
chore: Refactor utilities and support new Var utils and testcase
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6c69b41 commit 10aaaf4

File tree

6 files changed

+89
-47
lines changed

6 files changed

+89
-47
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,47 +19,6 @@ namespace conversion {
1919
namespace evaluators {
2020
namespace {
2121

22-
nvinfer1::ITensor* index_layer(
23-
ConversionCtx* ctx,
24-
const torch::jit::Node* n,
25-
nvinfer1::ITensor* input_tensor,
26-
int64_t index) {
27-
// index to access needs to be an at::Tensor
28-
at::Tensor indices = torch::tensor({index}).to(torch::kI32);
29-
auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices);
30-
31-
auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
32-
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
33-
auto indexed_tensor = gather_layer->getOutput(0);
34-
return indexed_tensor;
35-
}
36-
37-
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
38-
LOG_DEBUG("Using dynamic version of aten::size evaluator");
39-
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
40-
LOG_DEBUG("Input dimensions: " << in->getDimensions());
41-
auto shape_layer = ctx->net->addShape(*in);
42-
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
43-
auto shape_1d_tensor = shape_layer->getOutput(0);
44-
45-
if (n->inputs().size() != 1) {
46-
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
47-
auto dim = args.at(n->input(1)).unwrapToInt();
48-
// Handle negative axis by referring to nbDims of input Tensor
49-
dim = dim < 0 ? dim + maxDim : dim;
50-
LOG_DEBUG("Dimension to select: " << dim);
51-
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
52-
}
53-
54-
LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
55-
56-
auto tensor_holder = TensorContainer();
57-
tensor_holder.hold_tensor(shape_1d_tensor);
58-
auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
59-
60-
return shape_1d_ivalue;
61-
}
62-
6322
DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
6423
eq,
6524
"aten::eq",

core/conversion/evaluators/eval_util.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include "core/conversion/evaluators/eval_util.h"
12
#include <ATen/ATen.h>
23
#include "ATen/InitialTensorOptions.h"
34
#include "ATen/core/List.h"
@@ -6,12 +7,54 @@
67
#include "ATen/core/jit_type.h"
78
#include "c10/util/irange.h"
89
#include "core/util/prelude.h"
10+
#include "torch/torch.h"
911

1012
namespace torch_tensorrt {
1113
namespace core {
1214
namespace conversion {
1315
namespace evaluators {
1416

17+
nvinfer1::ITensor* index_layer(
18+
ConversionCtx* ctx,
19+
const torch::jit::Node* n,
20+
nvinfer1::ITensor* input_tensor,
21+
int64_t index) {
22+
// index to access needs to be an at::Tensor
23+
at::Tensor indices = torch::tensor({index}).to(torch::kI32);
24+
auto indices_out = converters::tensor_to_const(ctx, indices);
25+
26+
auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
27+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
28+
auto indexed_tensor = gather_layer->getOutput(0);
29+
return indexed_tensor;
30+
}
31+
32+
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
33+
LOG_DEBUG("Using dynamic version of aten::size evaluator");
34+
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
35+
LOG_DEBUG("Input dimensions: " << in->getDimensions());
36+
auto shape_layer = ctx->net->addShape(*in);
37+
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
38+
auto shape_1d_tensor = shape_layer->getOutput(0);
39+
40+
if (n->inputs().size() != 1) {
41+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
42+
auto dim = args.at(n->input(1)).unwrapToInt();
43+
// Handle negative axis by referring to nbDims of input Tensor
44+
dim = dim < 0 ? dim + maxDim : dim;
45+
LOG_DEBUG("Dimension to select: " << dim);
46+
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
47+
}
48+
49+
LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
50+
51+
auto tensor_holder = TensorContainer();
52+
tensor_holder.hold_tensor(shape_1d_tensor);
53+
auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
54+
55+
return shape_1d_ivalue;
56+
}
57+
1558
int64_t normalizeIndex(int64_t idx, int64_t list_size) {
1659
if (idx < 0) {
1760
// Handle negative indexing
@@ -128,7 +171,7 @@ void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) {
128171
}
129172

130173
// TODO: Conditionally enable truncation based on user setting
131-
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device = at::kCPU) {
174+
at::Tensor scalar_to_tensor(const at::Scalar& s, const at::Device device) {
132175
// This function is basically same with the one in
133176
// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/ScalarOps.h, what different here is that Int and Float
134177
// won't be upgraded to kDouble or kLong since we don't support these 2 types in conversion

core/conversion/evaluators/eval_util.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
#pragma once
22

3+
#include "core/conversion/evaluators/evaluators.h"
34
#include "torch/csrc/jit/ir/ir.h"
45

56
namespace torch_tensorrt {
67
namespace core {
78
namespace conversion {
89
namespace evaluators {
910

11+
nvinfer1::ITensor* index_layer(
12+
ConversionCtx* ctx,
13+
const torch::jit::Node* n,
14+
nvinfer1::ITensor* input_tensor,
15+
int64_t index);
16+
17+
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
18+
1019
c10::optional<torch::jit::IValue> toIValue(const torch::jit::Value* v);
1120
at::Tensor createTensorFromList(
1221
const torch::jit::IValue& data,

core/conversion/evaluators/prim.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,8 @@ auto prim_registrations =
8888
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
8989
}
9090
} else {
91-
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
92-
c10::TypePtr elementType = lt->getElementType();
93-
auto list = c10::impl::GenericList(elementType);
91+
// List would be of IValues (with ITensors embedded in them)
92+
auto list = c10::impl::GenericList(c10::AnyType::get());
9493
list.reserve(num_inputs);
9594
for (auto in : n->inputs()) {
9695
if (args.at(in).isITensor()) {

core/conversion/var/Var.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ bool Var::isBoolList() const {
218218
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
219219
TORCHTRT_CHECK(
220220
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
221-
LOG_DEBUG(" === Is INT list: " << ptr_.ivalue->isIntList());
222-
LOG_DEBUG(" === Is List: " << ptr_.ivalue->isList());
221+
TORCHTRT_CHECK(
222+
isITensorList(),
223+
"Expected IValue to be an ITensorList, however the type is "
224+
<< static_cast<std::underlying_type<IValueType>::type>(ivalue_type_));
223225
auto ivalue_list = ptr_.ivalue->toList();
224226
std::vector<nvinfer1::ITensor*> outputs;
225227
for (int i = 0; i < ivalue_list.size(); i++) {

tests/cpp/test_dynamic_size.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,33 @@ TEST(Converters, ATenResizeDynamicInputCorrectly) {
5959

6060
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
6161
}
62+
63+
TEST(Converters, ATenResizeGetItemDynShapeCorrectly) {
64+
const auto graph = R"IR(
65+
graph(%x.1 : Tensor):
66+
%3 : int = prim::Constant[value=-1]()
67+
%2 : int = prim::Constant[value=0]()
68+
%size.1 : int[] = aten::size(%x.1)
69+
%37 : int = aten::__getitem__(%size.1, %2)
70+
%39 : int[] = prim::ListConstruct(%37, %3)
71+
%7 : Tensor = aten::reshape(%x.1, %39)
72+
return (%7))IR";
73+
74+
auto g = std::make_shared<torch::jit::Graph>();
75+
76+
torch::jit::parseIR(graph, g.get());
77+
78+
auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
79+
80+
auto jit_in = at::clone(in);
81+
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
82+
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
83+
84+
auto trt_in = at::clone(in);
85+
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
86+
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
87+
88+
auto trt = trt_results[0].reshape(jit_results[0].sizes());
89+
90+
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
91+
}

0 commit comments

Comments
 (0)