
Commit 6c69b41

chore: Add ivalue type detections
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6754c79 commit 6c69b41

5 files changed: +116 additions, −9 deletions

core/conversion/converters/impl/shuffle.cpp

Lines changed: 17 additions & 5 deletions
@@ -72,11 +72,23 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
           std::vector<int64_t> new_shape;
           nvinfer1::ITensor* shape_tensor;
           if (ctx->input_is_dynamic) {
-            auto new_shape = args[1].unwrapToITensorList();
-            auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
-            TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
-            concat_layer->setAxis(static_cast<int32_t>(0));
-            shape_tensor = concat_layer->getOutput(0);
+            LOG_DEBUG("Using dynamic version of reshape layer");
+            if (args[1].isITensorList()) {
+              LOG_DEBUG("Shape tensor is an ITensorList");
+              auto new_shape = args[1].unwrapToITensorList();
+              auto concat_layer = ctx->net->addConcatenation(new_shape.data(), new_shape.size());
+              TORCHTRT_CHECK(concat_layer, "Unable to create concatenation layer from node: " << *n);
+              concat_layer->setAxis(static_cast<int32_t>(0));
+              shape_tensor = concat_layer->getOutput(0);
+            } else if (args[1].isIntList()) {
+              LOG_DEBUG("Shape tensor is an IntList");
+              auto shape_vec = args[1].unwrapToIntList().vec();
+              shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
+            } else {
+              LOG_ERROR(
+                  "Invalid IValue type of " << args[1].ivalue_type()
+                                            << " detected for shape tensor from node: " << *n);
+            }
           } else {
             new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
           }
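
The dynamic-shape branch above dispatches on how the reshape target reaches the converter. As a hypothetical illustration (not part of this commit), a graph whose target shape is assembled at runtime from aten::size values hands the converter a generic list of ITensors via prim::ListConstruct, while a constant int[] shape (as in the new test added below) takes the IntList branch and is folded into a constant shape tensor:

    // Sketch of a graph that would exercise the ITensorList path; the graph text
    // and variable names are illustrative, not taken from the repository.
    const auto dynamic_shape_graph = R"IR(
        graph(%x : Tensor):
          %0 : int = prim::Constant[value=0]()
          %neg1 : int = prim::Constant[value=-1]()
          %d0 : int = aten::size(%x, %0)
          %shape : int[] = prim::ListConstruct(%d0, %neg1)
          %y : Tensor = aten::reshape(%x, %shape)
          return (%y))IR";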

core/conversion/evaluators/prim.cpp

Lines changed: 0 additions & 1 deletion
@@ -88,7 +88,6 @@ auto prim_registrations =
                 return c10::optional<torch::jit::IValue>(std::move(torch::jit::IValue(list)));
               }
             } else {
-              LOG_DEBUG("==== NON CONST TYPES ==== ");
               c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
               c10::TypePtr elementType = lt->getElementType();
               auto list = c10::impl::GenericList(elementType);

core/conversion/var/Var.cpp

Lines changed: 64 additions & 1 deletion
@@ -21,6 +21,28 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
   ptr_.tensor = p;
 }
 
+Var::IValueType Var::determineIValueType(torch::jit::IValue* p) {
+  if (p->isInt()) {
+    return IValueType::kInt;
+  } else if (p->isDouble()) {
+    return IValueType::kDouble;
+  } else if (p->isBool()) {
+    return IValueType::kBool;
+  } else if (p->isTensor()) {
+    return IValueType::kTensor;
+  } else if (p->isIntList()) {
+    return IValueType::kIntList;
+  } else if (p->isDoubleList()) {
+    return IValueType::kDoubleList;
+  } else if (p->isBoolList()) {
+    return IValueType::kBoolList;
+  } else if (p->isTensorList()) {
+    return IValueType::kTensorList;
+  } else if (p->isList()) {
+    return IValueType::kITensorList;
+  }
+}
+
 Var::Var(const Var& a) {
   switch (a.type_) {
     case Type::kITensor:
@@ -30,6 +52,7 @@ Var::Var(const Var& a) {
     case Type::kIValue:
       ptr_.ivalue = a.ptr_.ivalue;
       type_ = Type::kIValue;
+      ivalue_type_ = determineIValueType(ptr_.ivalue);
       break;
     case Type::kNone:
     default:
@@ -47,6 +70,7 @@ Var& Var::operator=(const Var& a) {
     case Type::kIValue:
       ptr_.ivalue = a.ptr_.ivalue;
       type_ = Type::kIValue;
+      ivalue_type_ = determineIValueType(ptr_.ivalue);
       break;
     case Type::kNone:
     default:
@@ -59,6 +83,7 @@ Var& Var::operator=(const Var& a) {
 Var& Var::operator=(torch::jit::IValue* in) {
   ptr_.ivalue = in;
   type_ = Type::kIValue;
+  ivalue_type_ = determineIValueType(ptr_.ivalue);
   return (*this);
 }
 
@@ -72,6 +97,10 @@ Var::Type Var::type() const {
   return type_;
 }
 
+Var::IValueType Var::ivalue_type() const {
+  return ivalue_type_;
+}
+
 std::string Var::type_name() const {
   switch (type_) {
     case Type::kITensor:
@@ -147,7 +176,39 @@ bool Var::isITensor() const {
 }
 
 bool Var::isITensorList() const {
-  if (type_ == Type::kITensor) {
+  if (ivalue_type_ == IValueType::kITensorList) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool Var::isIntList() const {
+  if (ivalue_type_ == IValueType::kIntList) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool Var::isDoubleList() const {
+  if (ivalue_type_ == IValueType::kDoubleList) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool Var::isTensorList() const {
+  if (ivalue_type_ == IValueType::kTensorList) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool Var::isBoolList() const {
+  if (ivalue_type_ == IValueType::kBoolList) {
     return true;
   } else {
     return false;
@@ -157,6 +218,8 @@ bool Var::isITensorList() const {
 std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
   TORCHTRT_CHECK(
       isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
+  LOG_DEBUG(" === Is INT list: " << ptr_.ivalue->isIntList());
+  LOG_DEBUG(" === Is List: " << ptr_.ivalue->isList());
   auto ivalue_list = ptr_.ivalue->toList();
   std::vector<nvinfer1::ITensor*> outputs;
   for (int i = 0; i < ivalue_list.size(); i++) {
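
A minimal usage sketch of the new type-tracking API (hypothetical call site, not from the repository; it assumes Var is reachable as torch_tensorrt::core::conversion::Var per the Var.h changes below): assigning an IValue* to a Var records both type_ and the detected ivalue_type_, so callers can branch on isIntList()/isITensorList() or report ivalue_type() in diagnostics without re-inspecting the wrapped IValue.

    // Sketch only: namespace qualification and surrounding includes are assumptions.
    c10::List<int64_t> dims({-1, 4, 64});
    torch::jit::IValue shape_ivalue(dims);

    torch_tensorrt::core::conversion::Var shape_arg;
    shape_arg = &shape_ivalue;            // operator= caches determineIValueType(&shape_ivalue)

    if (shape_arg.isIntList()) {          // true: the IValue holds an int list
      auto shape_vec = shape_arg.unwrapToIntList().vec();  // {-1, 4, 64}
    }

The sketch assigns through operator= because that is one of the paths shown above that records ivalue_type_. Note also that determineIValueType checks the concrete list types (int, double, bool, tensor) before the generic isList() case, so only lists that are none of those map to kITensorList.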

core/conversion/var/Var.h

Lines changed: 8 additions & 1 deletion
@@ -15,7 +15,7 @@ namespace conversion {
 class Var : torch::CustomClassHolder {
  public:
   enum Type { kITensor, kIValue, kNone };
-
+  enum IValueType { kInt, kDouble, kBool, kTensor, kIntList, kDoubleList, kBoolList, kTensorList, kITensorList };
   Var();
   Var(torch::jit::IValue* p);
   Var(nvinfer1::ITensor* p);
@@ -60,9 +60,15 @@ class Var : torch::CustomClassHolder {
   bool isIValue() const;
   bool isITensor() const;
   bool isITensorList() const;
+  bool isTensorList() const;
+  bool isDoubleList() const;
+  bool isIntList() const;
+  bool isBoolList() const;
   bool isNone() const;
   Var::Type type() const;
+  Var::IValueType ivalue_type() const;
   std::string type_name() const;
+  Var::IValueType determineIValueType(torch::jit::IValue* p);
 
  private:
   union VarContainer {
@@ -73,6 +79,7 @@ class Var : torch::CustomClassHolder {
 
   VarContainer ptr_;
   Type type_;
+  IValueType ivalue_type_;
 };
 
 } // namespace conversion

tests/cpp/test_dynamic_size.cpp

Lines changed: 27 additions & 1 deletion
@@ -5,7 +5,7 @@
 #include "tests/util/util.h"
 #include "torch/csrc/jit/ir/irparser.h"
 
-TEST(Converters, ATenResizeDynamicInputCorrectly) {
+TEST(Converters, ATenResizeDynamicShapeCorrectly) {
   const auto graph = R"IR(
       graph(%x : Tensor):
         %3 : int = prim::Constant[value=0]()
@@ -33,3 +33,29 @@ TEST(Converters, ATenResizeDynamicInputCorrectly) {
 
   ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
 }
+
+TEST(Converters, ATenResizeDynamicInputCorrectly) {
+  const auto graph = R"IR(
+      graph(%x : Tensor):
+        %2 : int[] = prim::Constant[value=[-1, 4, 64]]()
+        %3 : Tensor = aten::reshape(%x, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+
+  torch::jit::parseIR(graph, g.get());
+
+  auto in = at::randint(1, 10, {16, 16, 16}, {at::kCUDA});
+
+  auto jit_in = at::clone(in);
+  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});
+
+  auto trt_in = at::clone(in);
+  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+  auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {in}, true);
+
+  auto trt = trt_results[0].reshape(jit_results[0].sizes());
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
+}
