Skip to content

Commit b2c8f59

Browse files
committed
chore: refactor code
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 4c17994 commit b2c8f59

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,23 @@ namespace conversion {
1919
namespace evaluators {
2020
namespace {
2121

22-
nvinfer1::ITensor* index_layer(){
23-
22+
nvinfer1::ITensor* index_layer(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* input_tensor, int64_t index){
23+
// index to access needs to be an at::Tensor
24+
at::Tensor indices = torch::tensor({index}).to(torch::kI32);
25+
auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices);
26+
27+
auto gather_layer = ctx->net->addGather(*input_tensor, *indices_out, 0);
28+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
29+
auto indexed_tensor = gather_layer->getOutput(0);
30+
return indexed_tensor;
2431
}
2532

2633
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
2734
LOG_DEBUG("Using dynamic version of aten::size evaluator");
2835
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
2936
LOG_DEBUG("Input dimensions: " << in->getDimensions());
3037
auto shape_layer = ctx->net->addShape(*in);
38+
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
3139
auto shape_1d_tensor = shape_layer->getOutput(0);
3240

3341
if (n->inputs().size() != 1){
@@ -36,15 +44,9 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kw
3644
// Handle negative axis by referring to nbDims of input Tensor
3745
dim = dim < 0 ? dim + maxDim : dim;
3846
LOG_DEBUG("Dimension to select: " << dim);
39-
40-
// index to access needs to be an at::Tensor
41-
at::Tensor indices = torch::tensor({dim}).to(torch::kI32);
42-
auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices);
43-
44-
auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0);
45-
shape_1d_tensor = gather_layer->getOutput(0);
47+
shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
4648
}
47-
49+
4850
LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
4951

5052
auto tensor_holder = TensorContainer();
@@ -364,13 +366,13 @@ auto aten_registrations TORCHTRT_UNUSED =
364366
TORCHTRT_CHECK(
365367
normalized_idx >= 0 || normalized_idx < list_size, "List index out of range (aten::__getitem__)");
366368
return list.get(normalized_idx);
367-
} elif (list_input.isITensor()){
368-
return dynamic_size_layer(ctx, n, args);
369+
} else if(list_input.isITensor()){
370+
auto indexed_tensor = index_layer(ctx, n, list_input.ITensorOrFreeze(ctx), idx);
371+
auto tensor_holder = TensorContainer();
372+
tensor_holder.hold_tensor(indexed_tensor);
373+
auto indexed_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
374+
return indexed_ivalue;
369375
}
370-
371-
372-
373-
374376
},
375377
EvalOptions().validSchemas({
376378
"aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",

core/conversion/evaluators/prim.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ auto prim_registrations =
4848
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4949
const auto num_inputs = n->inputs().size();
5050
if (constTypesOnly(args)) {
51-
LOG_DEBUG("==== CONST TYPES ARGS ==== ");
5251
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
5352
if (torch::jit::IntType::get() == lt->getElementType()) {
5453
c10::List<int64_t> list;
@@ -106,8 +105,7 @@ auto prim_registrations =
106105
auto ival = torch::jit::IValue();
107106
list.emplace_back(std::move(ival));
108107
} else if (args.at(in).IValue()->isInt()) {
109-
LOG_DEBUG("==== INT TYPE ITENSOR ==== ");
110-
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}));
108+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}).to(torch::kI32));
111109
auto tensor_holder = TensorContainer();
112110
tensor_holder.hold_tensor(itensor);
113111
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

0 commit comments

Comments
 (0)