@@ -21,28 +21,6 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
21
21
ptr_.tensor = p;
22
22
}
23
23
24
- Var::IValueType Var::determineIValueType (torch::jit::IValue* p) {
25
- if (p->isInt ()) {
26
- return IValueType::kInt ;
27
- } else if (p->isDouble ()) {
28
- return IValueType::kDouble ;
29
- } else if (p->isBool ()) {
30
- return IValueType::kBool ;
31
- } else if (p->isTensor ()) {
32
- return IValueType::kTensor ;
33
- } else if (p->isIntList ()) {
34
- return IValueType::kIntList ;
35
- } else if (p->isDoubleList ()) {
36
- return IValueType::kDoubleList ;
37
- } else if (p->isBoolList ()) {
38
- return IValueType::kBoolList ;
39
- } else if (p->isTensorList ()) {
40
- return IValueType::kTensorList ;
41
- } else if (p->isList ()) {
42
- return IValueType::kITensorList ;
43
- }
44
- }
45
-
46
24
Var::Var (const Var& a) {
47
25
switch (a.type_ ) {
48
26
case Type::kITensor :
@@ -52,7 +30,6 @@ Var::Var(const Var& a) {
52
30
case Type::kIValue :
53
31
ptr_.ivalue = a.ptr_ .ivalue ;
54
32
type_ = Type::kIValue ;
55
- ivalue_type_ = determineIValueType (ptr_.ivalue );
56
33
break ;
57
34
case Type::kNone :
58
35
default :
@@ -70,7 +47,6 @@ Var& Var::operator=(const Var& a) {
70
47
case Type::kIValue :
71
48
ptr_.ivalue = a.ptr_ .ivalue ;
72
49
type_ = Type::kIValue ;
73
- ivalue_type_ = determineIValueType (ptr_.ivalue );
74
50
break ;
75
51
case Type::kNone :
76
52
default :
@@ -83,7 +59,6 @@ Var& Var::operator=(const Var& a) {
83
59
Var& Var::operator =(torch::jit::IValue* in) {
84
60
ptr_.ivalue = in;
85
61
type_ = Type::kIValue ;
86
- ivalue_type_ = determineIValueType (ptr_.ivalue );
87
62
return (*this );
88
63
}
89
64
@@ -97,10 +72,6 @@ Var::Type Var::type() const {
97
72
return type_;
98
73
}
99
74
100
- Var::IValueType Var::ivalue_type () const {
101
- return ivalue_type_;
102
- }
103
-
104
75
std::string Var::type_name () const {
105
76
switch (type_) {
106
77
case Type::kITensor :
@@ -175,40 +146,8 @@ bool Var::isITensor() const {
175
146
}
176
147
}
177
148
178
- bool Var::isITensorList () const {
179
- if (ivalue_type_ == IValueType::kITensorList ) {
180
- return true ;
181
- } else {
182
- return false ;
183
- }
184
- }
185
-
186
- bool Var::isIntList () const {
187
- if (ivalue_type_ == IValueType::kIntList ) {
188
- return true ;
189
- } else {
190
- return false ;
191
- }
192
- }
193
-
194
- bool Var::isDoubleList () const {
195
- if (ivalue_type_ == IValueType::kDoubleList ) {
196
- return true ;
197
- } else {
198
- return false ;
199
- }
200
- }
201
-
202
- bool Var::isTensorList () const {
203
- if (ivalue_type_ == IValueType::kTensorList ) {
204
- return true ;
205
- } else {
206
- return false ;
207
- }
208
- }
209
-
210
- bool Var::isBoolList () const {
211
- if (ivalue_type_ == IValueType::kBoolList ) {
149
+ bool Var::isITensorList () {
150
+ if (isList () && ptr_.ivalue ->isCustomClass ()) {
212
151
return true ;
213
152
} else {
214
153
return false ;
@@ -218,10 +157,7 @@ bool Var::isBoolList() const {
218
157
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList () {
219
158
TORCHTRT_CHECK (
220
159
isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
221
- TORCHTRT_CHECK (
222
- isITensorList (),
223
- " Expected IValue to be an ITensorList, however the type is "
224
- << static_cast <std::underlying_type<IValueType>::type>(ivalue_type_));
160
+ TORCHTRT_CHECK (isITensorList (), " Expected IValue to be an ITensorList" );
225
161
auto ivalue_list = ptr_.ivalue ->toList ();
226
162
std::vector<nvinfer1::ITensor*> outputs;
227
163
for (int i = 0 ; i < ivalue_list.size (); i++) {
0 commit comments