Skip to content

Commit 4c17994

Browse files
committed
feat: Implement dynamic version of aten::size
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 0e42d97 commit 4c17994

File tree

7 files changed

+147
-104
lines changed

7 files changed

+147
-104
lines changed

core/conversion/converters/BUILD

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,8 @@ cc_library(
6262
"impl/constant_pad.cpp",
6363
"impl/conv_deconv.cpp",
6464
"impl/cumsum.cpp",
65-
<<<<<<< HEAD
6665
"impl/dual_ops.cpp",
67-
=======
6866
"impl/einsum.cpp",
69-
>>>>>>> main
7067
"impl/element_wise.cpp",
7168
"impl/expand.cpp",
7269
"impl/interpolate.cpp",

core/conversion/converters/impl/shuffle.cpp

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,39 +67,33 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
6767
.pattern(
6868
{"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
6969
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
70-
auto in = args[0].ITensorOrFreeze(ctx);
71-
std::cout << "====1====" << std::endl;
72-
auto in_shape = util::toVec(in->getDimensions());
73-
std::cout << "====2====" << std::endl;
74-
std::vector<int64_t> new_shape;
75-
if (ctx->input_is_dynamic) {
76-
std::cout << "====3====: " << args[1].size() << std::endl;
77-
// new_shape = util::toVec(args[1].unwrapToIntList().vec());
78-
new_shape = util::toVec(args[1].unwrapToITensorList());
79-
std::cout << "====4====" << std::endl;
80-
int nbDynamicDims = 0;
81-
for (size_t i = 0; i < new_shape.size(); i++) {
82-
if (in_shape[i] == -1)
83-
nbDynamicDims++;
84-
}
85-
if (nbDynamicDims > 1) {
86-
TORCHTRT_THROW_ERROR(
87-
"Resize is currently not supported when target shape contains more than one dynamic dimension");
88-
}
89-
} else {
90-
std::cout << "====5====" << std::endl;
91-
new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
92-
}
93-
std::cout << "====6====" << std::endl;
94-
auto shuffle = ctx->net->addShuffle(*in);
95-
TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
96-
shuffle->setReshapeDimensions(util::toDims(new_shape));
97-
shuffle->setName(util::node_info(n).c_str());
98-
99-
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
100-
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
101-
102-
return true;
70+
auto in = args[0].ITensorOrFreeze(ctx);
71+
auto in_shape = util::toVec(in->getDimensions());
72+
std::vector<int64_t> new_shape;
73+
nvinfer1::ITensor* shape_tensor;
74+
if (ctx->input_is_dynamic) {
75+
auto new_shape = args[1].unwrapToITensorList();
76+
auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
77+
TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
78+
concat_layer->setAxis(static_cast<int32_t>(0));
79+
shape_tensor = concat_layer->getOutput(0);
80+
} else {
81+
auto new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
82+
}
83+
auto shuffle = ctx->net->addShuffle(*in);
84+
shuffle->setName(util::node_info(n).c_str());
85+
TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
86+
87+
if (ctx->input_is_dynamic){
88+
shuffle->setInput(1, *shape_tensor);
89+
} else {
90+
shuffle->setReshapeDimensions(util::toDims(new_shape));
91+
}
92+
93+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
94+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
95+
96+
return true;
10397
}})
10498
.pattern(
10599
{"aten::view(Tensor(a) self, int[] size) -> (Tensor(a))",

core/conversion/evaluators/aten.cpp

Lines changed: 105 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,41 @@ namespace conversion {
1919
namespace evaluators {
2020
namespace {
2121

22+
nvinfer1::ITensor* index_layer(){
23+
24+
}
25+
26+
c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args){
27+
LOG_DEBUG("Using dynamic version of aten::size evaluator");
28+
auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
29+
LOG_DEBUG("Input dimensions: " << in->getDimensions());
30+
auto shape_layer = ctx->net->addShape(*in);
31+
auto shape_1d_tensor = shape_layer->getOutput(0);
32+
33+
if (n->inputs().size() != 1){
34+
auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
35+
auto dim = args.at(n->input(1)).unwrapToInt();
36+
// Handle negative axis by refering to nbDims of input Tensor
37+
dim = dim < 0 ? dim + maxDim : dim;
38+
LOG_DEBUG("Dimension to select: " << dim);
39+
40+
// index to access needs to be an at::Tensor
41+
at::Tensor indices = torch::tensor({dim}).to(torch::kI32);
42+
auto indices_out = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, indices);
43+
44+
auto gather_layer = ctx->net->addGather(*shape_1d_tensor, *indices_out, 0);
45+
shape_1d_tensor = gather_layer->getOutput(0);
46+
}
47+
48+
LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
49+
50+
auto tensor_holder = TensorContainer();
51+
tensor_holder.hold_tensor(shape_1d_tensor);
52+
auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
53+
54+
return shape_1d_ivalue;
55+
}
56+
2257
DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
2358
eq,
2459
"aten::eq",
@@ -176,7 +211,7 @@ auto aten_registrations TORCHTRT_UNUSED =
176211
{c10::Symbol::fromQualString("aten::full_like"),
177212
// aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None,
178213
// Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> (Tensor)
179-
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
214+
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
180215
// Override options related to layout and device for TensorRT
181216
auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
182217
auto input_tensor_var = args.at(n->input(0));
@@ -262,67 +297,80 @@ auto aten_registrations TORCHTRT_UNUSED =
262297
return static_cast<int64_t>(list.size());
263298
},
264299
EvalOptions().validSchemas({"aten::len.t(t[] a) -> (int)"})})
265-
// .evaluator(
266-
// {c10::Symbol::fromQualString("aten::size"),
267-
// [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
268-
// LOG_WARNING("There may be undefined behavior using dynamic shape and aten::size");
269-
// auto tensor_var = args.at(n->input(0));
270-
// if (n->inputs().size() == 1) {
271-
// if (tensor_var.isITensor()) {
272-
// auto tensor = tensor_var.ITensor();
273-
// return util::toVec(tensor->getDimensions());
274-
// } else if (tensor_var.IValue()->isTensor()) {
275-
// auto tensor = tensor_var.unwrapToTensor();
276-
// return tensor.sizes();
277-
// } else if (tensor_var.IValue()->isCustomClass()) {
278-
// auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
279-
// return util::toVec(tensor->getDimensions());
280-
// } else {
281-
// TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
282-
// }
283-
// } else {
284-
// auto dim = args.at(n->input(1)).unwrapToInt();
285-
// if (tensor_var.isITensor()) {
286-
// auto tensor = tensor_var.ITensor();
287-
// auto dims = util::toVec(tensor->getDimensions());
288-
// auto nbDims = tensor->getDimensions().nbDims;
289-
// if (dim < 0) {
290-
// dim += nbDims;
291-
// }
292-
// return dims[dim];
293-
// } else if (tensor_var.IValue()->isTensor()) {
294-
// auto tensor = tensor_var.unwrapToTensor();
295-
// auto nbDims = tensor.sizes().size();
296-
// if (dim < 0) {
297-
// dim += nbDims;
298-
// }
299-
// return tensor.sizes()[dim];
300-
// } else if (tensor_var.IValue()->isCustomClass()) {
301-
// auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
302-
// auto dims = util::toVec(tensor->getDimensions());
303-
// auto nbDims = tensor->getDimensions().nbDims;
304-
// if (dim < 0) {
305-
// dim += nbDims;
306-
// }
307-
// return dims[dim];
308-
// } else {
309-
// TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
310-
// }
311-
// }
312-
// },
313-
// EvalOptions().validSchemas(
314-
// {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
300+
.evaluator(
    {c10::Symbol::fromQualString("aten::size"),
     [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
       // aten::size / aten::size.int: report tensor dimensions. For ITensor
       // inputs under dynamic shapes, defer to dynamic_size_layer, which emits
       // shape layers resolved at runtime instead of constant-folding here.
       auto tensor_var = args.at(n->input(0));
       if (n->inputs().size() == 1) {
         // aten::size(Tensor self) -> (int[])
         if (tensor_var.isITensor()) {
           auto tensor = tensor_var.ITensor();
           if (ctx->input_is_dynamic) {
             return dynamic_size_layer(ctx, n, args);
           }
           return util::toVec(tensor->getDimensions());
         } else if (tensor_var.IValue()->isTensor()) {
           auto tensor = tensor_var.unwrapToTensor();
           return tensor.sizes();
         } else if (tensor_var.IValue()->isCustomClass()) {
           auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
           return util::toVec(tensor->getDimensions());
         } else {
           TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
         }
       } else {
         // aten::size.int(Tensor self, int dim) -> (int)
         auto dim = args.at(n->input(1)).unwrapToInt();
         if (tensor_var.isITensor()) {
           if (ctx->input_is_dynamic) {
             return dynamic_size_layer(ctx, n, args);
           }
           auto tensor = tensor_var.ITensor();
           auto dims = util::toVec(tensor->getDimensions());
           auto nbDims = tensor->getDimensions().nbDims;
           // Normalize a negative axis against the tensor rank.
           if (dim < 0) {
             dim += nbDims;
           }
           return dims[dim];
         } else if (tensor_var.IValue()->isTensor()) {
           auto tensor = tensor_var.unwrapToTensor();
           auto nbDims = tensor.sizes().size();
           if (dim < 0) {
             dim += nbDims;
           }
           return tensor.sizes()[dim];
         } else if (tensor_var.IValue()->isCustomClass()) {
           auto tensor = tensor_var.IValue()->toCustomClass<TensorContainer>()->tensor();
           auto dims = util::toVec(tensor->getDimensions());
           auto nbDims = tensor->getDimensions().nbDims;
           if (dim < 0) {
             dim += nbDims;
           }
           return dims[dim];
         } else {
           TORCHTRT_THROW_ERROR("IValue is not some class of Tensor. Found: " << tensor_var.IValue()->type());
         }
       }
     },
     EvalOptions().validSchemas(
         {"aten::size(Tensor self) -> (int[])", "aten::size.int(Tensor self, int dim) -> (int)"})})
315355
.evaluator(
316356
{c10::Symbol::fromQualString("aten::__getitem__"),
317357
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
  // aten::__getitem__: index into a constant IValue list, or — when the
  // "list" is actually a runtime shape ITensor — gather the element via
  // dynamic_size_layer.
  auto list_input = args.at(n->input(0));
  auto idx = args.at(n->input(1)).unwrapToInt();
  if (list_input.isIValue()) {
    auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
    const int64_t list_size = list.size();
    const int64_t normalized_idx = normalizeIndex(idx, list_size);
    // BUGFIX: use && — the previous `>= 0 || < list_size` condition was a
    // tautology and never rejected out-of-range indices.
    TORCHTRT_CHECK(
        normalized_idx >= 0 && normalized_idx < list_size, "List index out of range (aten::__getitem__)");
    return list.get(normalized_idx);
  } else if (list_input.isITensor()) {
    // BUGFIX: `elif` is not valid C++; also this branch previously let control
    // fall off the end of the lambda (UB for a non-void return type).
    return dynamic_size_layer(ctx, n, args);
  } else {
    TORCHTRT_THROW_ERROR("Unsupported input type for aten::__getitem__: " << list_input.type_name());
  }
},
327375
EvalOptions().validSchemas({
328376
"aten::__getitem__.t(t[](a) list, int idx) -> (t(*))",

core/conversion/evaluators/prim.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ auto prim_registrations =
4848
[](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
4949
const auto num_inputs = n->inputs().size();
5050
if (constTypesOnly(args)) {
51+
LOG_DEBUG("==== CONST TYPES ARGS ==== ");
5152
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
5253
if (torch::jit::IntType::get() == lt->getElementType()) {
5354
c10::List<int64_t> list;
@@ -89,6 +90,7 @@ auto prim_registrations =
8990
return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
9091
}
9192
} else {
93+
LOG_DEBUG("==== NON CONST TYPES ==== ");
9294
c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
9395
c10::TypePtr elementType = lt->getElementType();
9496
auto list = c10::impl::GenericList(elementType);
@@ -104,7 +106,8 @@ auto prim_registrations =
104106
auto ival = torch::jit::IValue();
105107
list.emplace_back(std::move(ival));
106108
} else if (args.at(in).IValue()->isInt()) {
107-
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor(args.at(in).unwrapToInt()));
109+
LOG_DEBUG("==== INT TYPE ITENSOR ==== ");
110+
auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor({args.at(in).unwrapToInt()}));
108111
auto tensor_holder = TensorContainer();
109112
tensor_holder.hold_tensor(itensor);
110113
auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

core/conversion/var/Var.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,21 +147,23 @@ bool Var::isITensor() const {
147147
}
148148

149149
bool Var::isITensorList() const {
150-
LOG_DEBUG("===== TYPE NAME: " << type_name());
151150
if (type_ == Type::kITensor) {
152151
return true;
153152
} else {
154153
return false;
155154
}
156155
}
157156

158-
bool Var::unwrapToITensorList() {
157+
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
159158
TORCHTRT_CHECK(
160159
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
161-
LOG_DEBUG("===== TYPE NAME: " << type_name());
162-
auto ivalue = ptr_.ivalue;
163-
return false;
164-
// return ptr_.ivalue->to<nvinfer1::ITensor*>();
160+
auto ivalue_list = ptr_.ivalue->toList();
161+
std::vector<nvinfer1::ITensor*> outputs;
162+
for (int i=0; i < ivalue_list.size(); i++){
163+
auto element = ivalue_list.get(i).toCustomClass<TensorContainer>()->tensor();
164+
outputs.push_back(std::move(element));
165+
}
166+
return outputs;
165167
}
166168

167169
bool Var::isIValue() const {

core/conversion/var/Var.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class Var : torch::CustomClassHolder {
4343
c10::Scalar unwrapToScalar();
4444
c10::List<int64_t> unwrapToIntList(c10::List<int64_t> default_val);
4545
c10::List<int64_t> unwrapToIntList();
46-
c10::List<nvinfer1::ITensor*> unwrapToITensorList();
46+
std::vector<nvinfer1::ITensor*> unwrapToITensorList();
4747
c10::List<double> unwrapToDoubleList(c10::List<double> default_val);
4848
c10::List<double> unwrapToDoubleList();
4949
c10::List<bool> unwrapToBoolList(c10::List<bool> default_val);

py/requirements.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
numpy
22
pybind11==2.6.2
3-
--extra-index-url https://download.pytorch.org/whl/nightly/cu117
4-
torch==2.0.0.dev20230103+cu117
5-
torchvision==0.15.0.dev20230103+cu117
3+
torch==1.13.0
4+
torchvision==0.14.0
65
--extra-index-url https://pypi.ngc.nvidia.com
76
tensorrt==8.5.1.7

0 commit comments

Comments
 (0)