Skip to content

Commit 34e7b09

Browse files
committed
refactor: Implement a macro for pytorch type checking
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 1d9200d commit 34e7b09

File tree

4 files changed

+38
-77
lines changed

4 files changed

+38
-77
lines changed

core/conversion/converters/impl/shuffle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
8686
shape_tensor = tensor_to_const(ctx, torch::tensor(shape_vec).to(torch::kI32));
8787
} else {
8888
LOG_ERROR(
89-
"Invalid IValue type of " << args[1].ivalue_type()
89+
"Invalid IValue type of " << args[1].IValue()->type()
9090
<< " detected for shape tensor from node: " << *n);
9191
}
9292
} else {

core/conversion/var/Var.cpp

Lines changed: 3 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,6 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
2121
ptr_.tensor = p;
2222
}
2323

24-
Var::IValueType Var::determineIValueType(torch::jit::IValue* p) {
25-
if (p->isInt()) {
26-
return IValueType::kInt;
27-
} else if (p->isDouble()) {
28-
return IValueType::kDouble;
29-
} else if (p->isBool()) {
30-
return IValueType::kBool;
31-
} else if (p->isTensor()) {
32-
return IValueType::kTensor;
33-
} else if (p->isIntList()) {
34-
return IValueType::kIntList;
35-
} else if (p->isDoubleList()) {
36-
return IValueType::kDoubleList;
37-
} else if (p->isBoolList()) {
38-
return IValueType::kBoolList;
39-
} else if (p->isTensorList()) {
40-
return IValueType::kTensorList;
41-
} else if (p->isList()) {
42-
return IValueType::kITensorList;
43-
}
44-
}
45-
4624
Var::Var(const Var& a) {
4725
switch (a.type_) {
4826
case Type::kITensor:
@@ -52,7 +30,6 @@ Var::Var(const Var& a) {
5230
case Type::kIValue:
5331
ptr_.ivalue = a.ptr_.ivalue;
5432
type_ = Type::kIValue;
55-
ivalue_type_ = determineIValueType(ptr_.ivalue);
5633
break;
5734
case Type::kNone:
5835
default:
@@ -70,7 +47,6 @@ Var& Var::operator=(const Var& a) {
7047
case Type::kIValue:
7148
ptr_.ivalue = a.ptr_.ivalue;
7249
type_ = Type::kIValue;
73-
ivalue_type_ = determineIValueType(ptr_.ivalue);
7450
break;
7551
case Type::kNone:
7652
default:
@@ -83,7 +59,6 @@ Var& Var::operator=(const Var& a) {
8359
Var& Var::operator=(torch::jit::IValue* in) {
8460
ptr_.ivalue = in;
8561
type_ = Type::kIValue;
86-
ivalue_type_ = determineIValueType(ptr_.ivalue);
8762
return (*this);
8863
}
8964

@@ -97,10 +72,6 @@ Var::Type Var::type() const {
9772
return type_;
9873
}
9974

100-
Var::IValueType Var::ivalue_type() const {
101-
return ivalue_type_;
102-
}
103-
10475
std::string Var::type_name() const {
10576
switch (type_) {
10677
case Type::kITensor:
@@ -175,40 +146,8 @@ bool Var::isITensor() const {
175146
}
176147
}
177148

178-
bool Var::isITensorList() const {
179-
if (ivalue_type_ == IValueType::kITensorList) {
180-
return true;
181-
} else {
182-
return false;
183-
}
184-
}
185-
186-
bool Var::isIntList() const {
187-
if (ivalue_type_ == IValueType::kIntList) {
188-
return true;
189-
} else {
190-
return false;
191-
}
192-
}
193-
194-
bool Var::isDoubleList() const {
195-
if (ivalue_type_ == IValueType::kDoubleList) {
196-
return true;
197-
} else {
198-
return false;
199-
}
200-
}
201-
202-
bool Var::isTensorList() const {
203-
if (ivalue_type_ == IValueType::kTensorList) {
204-
return true;
205-
} else {
206-
return false;
207-
}
208-
}
209-
210-
bool Var::isBoolList() const {
211-
if (ivalue_type_ == IValueType::kBoolList) {
149+
bool Var::isITensorList() {
150+
if (isList() && ptr_.ivalue->isCustomClass()) {
212151
return true;
213152
} else {
214153
return false;
@@ -218,10 +157,7 @@ bool Var::isBoolList() const {
218157
std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList() {
219158
TORCHTRT_CHECK(
220159
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
221-
TORCHTRT_CHECK(
222-
isITensorList(),
223-
"Expected IValue to be an ITensorList, however the type is "
224-
<< static_cast<std::underlying_type<IValueType>::type>(ivalue_type_));
160+
TORCHTRT_CHECK(isITensorList(), "Expected IValue to be an ITensorList");
225161
auto ivalue_list = ptr_.ivalue->toList();
226162
std::vector<nvinfer1::ITensor*> outputs;
227163
for (int i = 0; i < ivalue_list.size(); i++) {

core/conversion/var/Var.h

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace conversion {
1515
class Var : torch::CustomClassHolder {
1616
public:
1717
enum Type { kITensor, kIValue, kNone };
18-
enum IValueType { kInt, kDouble, kBool, kTensor, kIntList, kDoubleList, kBoolList, kTensorList, kITensorList };
18+
1919
Var();
2020
Var(torch::jit::IValue* p);
2121
Var(nvinfer1::ITensor* p);
@@ -59,16 +59,23 @@ class Var : torch::CustomClassHolder {
5959

6060
bool isIValue() const;
6161
bool isITensor() const;
62-
bool isITensorList() const;
63-
bool isTensorList() const;
64-
bool isDoubleList() const;
65-
bool isIntList() const;
66-
bool isBoolList() const;
6762
bool isNone() const;
63+
64+
bool isInt();
65+
bool isDouble();
66+
bool isBool();
67+
bool isString();
68+
bool isScalar();
69+
bool isTensor();
70+
bool isIntList();
71+
bool isDoubleList();
72+
bool isBoolList();
73+
bool isTensorList();
74+
bool isITensorList();
75+
bool isList();
76+
6877
Var::Type type() const;
69-
Var::IValueType ivalue_type() const;
7078
std::string type_name() const;
71-
Var::IValueType determineIValueType(torch::jit::IValue* p);
7279

7380
private:
7481
union VarContainer {
@@ -79,7 +86,6 @@ class Var : torch::CustomClassHolder {
7986

8087
VarContainer ptr_;
8188
Type type_;
82-
IValueType ivalue_type_;
8389
};
8490

8591
} // namespace conversion

core/conversion/var/Var_inl.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ namespace torch_tensorrt {
44
namespace core {
55
namespace conversion {
66

7+
#define DEFINE_IS_IVAL_TYPE(method_variant) \
8+
inline bool Var::is##method_variant() { \
9+
TORCHTRT_CHECK( \
10+
isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name()); \
11+
return ptr_.ivalue->is##method_variant(); \
12+
}
13+
714
#define DEFINE_UNWRAP_TO(ival_type, method_variant) \
815
template <> \
916
inline ival_type Var::unwrapTo<ival_type>() { \
@@ -34,6 +41,18 @@ namespace conversion {
3441
return this->unwrapTo<ival_type>(); \
3542
}
3643

44+
DEFINE_IS_IVAL_TYPE(Int)
45+
DEFINE_IS_IVAL_TYPE(Double)
46+
DEFINE_IS_IVAL_TYPE(Bool)
47+
DEFINE_IS_IVAL_TYPE(String)
48+
DEFINE_IS_IVAL_TYPE(Scalar)
49+
DEFINE_IS_IVAL_TYPE(Tensor)
50+
DEFINE_IS_IVAL_TYPE(IntList)
51+
DEFINE_IS_IVAL_TYPE(DoubleList)
52+
DEFINE_IS_IVAL_TYPE(BoolList)
53+
DEFINE_IS_IVAL_TYPE(TensorList)
54+
DEFINE_IS_IVAL_TYPE(List)
55+
3756
DEFINE_UNWRAP_TO(at::Tensor, Tensor)
3857
DEFINE_UNWRAP_TO(int64_t, Int)
3958
DEFINE_UNWRAP_TO(double, Double)

0 commit comments

Comments
 (0)